Mirror of https://github.com/vllm-project/vllm.git (synced 2025-10-20 23:03:52 +08:00)

Compare commits: v0.11.0rc6...prune-samp (4 commits)

Commits: ce24bc8a9e, 2c4101068c, 91d7dedc06, 806a377171

@ -304,7 +304,6 @@ steps:
  - tests/conftest.py
  commands:
  - pytest -v -s samplers
  - VLLM_USE_FLASHINFER_SAMPLER=1 pytest -v -s samplers

- label: LoRA Test %N # 15min each
  mirror_hardwares: [amdexperimental]

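The two pipeline commands above run the samplers suite once with the default sampler and once with the FlashInfer sampler enabled. A minimal local reproduction of the second command might look like the sketch below; the environment variable name is taken verbatim from the CI line, while the test path is assumed relative to the repository root.

    # Hedged sketch: rerun the samplers suite with the FlashInfer sampler enabled.
    import os
    import subprocess

    env = dict(os.environ, VLLM_USE_FLASHINFER_SAMPLER="1")
    subprocess.run(["pytest", "-v", "-s", "tests/samplers"], env=env, check=True)
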
@ -10,13 +10,6 @@ from transformers import AutoModelForSeq2SeqLM

from vllm.assets.audio import AudioAsset


-@pytest.fixture(autouse=True)
-def v1(run_with_both_engines):
-    """We can run both engines for this test."""
-    pass


# FIXME(zhuohan): The test can not pass if we:
# 1. Increase max_tokens to 256.
# 2. Increase beam_width to 8.
@ -27,7 +20,7 @@ MM_BEAM_WIDTHS = [2]
MODELS = ["TinyLlama/TinyLlama-1.1B-Chat-v1.0"]


-@pytest.mark.skip_v1  # FIXME: This fails on V1 right now.
+@pytest.mark.skip(reason="Fails on V1")  # FIXME: This fails on V1 right now.
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", MAX_TOKENS)

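For reference, the engine-selection fixtures being removed in these hunks all follow the same pattern: an autouse pytest fixture that pins the engine through the VLLM_USE_V1 environment variable. A minimal sketch of that pattern (the fixture name here is illustrative, not part of the diff):

    import pytest

    @pytest.fixture(autouse=True)
    def force_v0_engine(monkeypatch):
        """Pin every test in this module to the V0 engine."""
        # VLLM_USE_V1 is the switch used throughout these test modules.
        monkeypatch.setenv("VLLM_USE_V1", "0")
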
@ -9,13 +9,6 @@ import pytest

from vllm import SamplingParams


-@pytest.fixture(autouse=True)
-def v1(run_with_both_engines):
-    """We can run both engines for this test."""
-    pass


# We also test with llama because it has generation_config to specify EOS
# (past regression).
MODELS = ["distilbert/distilgpt2", "meta-llama/Llama-3.2-1B"]
@ -31,7 +24,7 @@ def test_ignore_eos(
    dtype: str,
    max_tokens: int,
) -> None:
-    with vllm_runner(model, dtype=dtype) as vllm_model:
+    with vllm_runner(model, enforce_eager=True, dtype=dtype) as vllm_model:
        sampling_params = SamplingParams(max_tokens=max_tokens,
                                         ignore_eos=True)

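As context for the hunk above: ignore_eos forces generation to continue past the EOS token, so the completion contains exactly max_tokens tokens. A small offline-inference sketch of the same behaviour through the public API (model and prompt are illustrative choices, not from the diff):

    from vllm import LLM, SamplingParams

    llm = LLM(model="distilbert/distilgpt2", enforce_eager=True)
    params = SamplingParams(max_tokens=16, ignore_eos=True)

    outputs = llm.generate(["The capital of France is"], params)
    # With ignore_eos=True the completion should contain exactly max_tokens tokens.
    print(len(outputs[0].outputs[0].token_ids))
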
@ -11,19 +11,10 @@ from ..conftest import VllmRunner
MODELS = ["distilbert/distilgpt2"]


-@pytest.fixture(scope="function", autouse=True)
-def use_v0_only(monkeypatch):
-    """
-    This module is V0 only since it uses dtype=float, so
-    set VLLM_USE_V1=0 for all tests in the module.
-    """
-    monkeypatch.setenv('VLLM_USE_V1', '0')


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype",
                         ["float"])  # needed for comparing logprobs with HF
@pytest.mark.parametrize("chunked_prefill_token_size", [1, 4, 16, -1])
@pytest.mark.parametrize("max_num_batched_tokens", [4])
@pytest.mark.parametrize("num_top_logprobs", [0, 6])  # 32000 == vocab_size
@pytest.mark.parametrize("detokenize", [True, False])
def test_get_prompt_logprobs(
@ -31,19 +22,11 @@ def test_get_prompt_logprobs(
    vllm_runner,
    model,
    dtype,
    chunked_prefill_token_size: int,
    max_num_batched_tokens: int,
    num_top_logprobs: int,
    detokenize: bool,
    example_prompts,
):
    max_num_seqs = 256
    enable_chunked_prefill = False
    max_num_batched_tokens = None
    if chunked_prefill_token_size != -1:
        enable_chunked_prefill = True
        max_num_seqs = min(chunked_prefill_token_size, max_num_seqs)
        max_num_batched_tokens = chunked_prefill_token_size

    max_tokens = 5
    with hf_runner(model, dtype=dtype) as hf_model:
        hf_logprobs = hf_model.generate_greedy_logprobs(
@ -55,9 +38,8 @@ def test_get_prompt_logprobs(
            model,
            dtype=dtype,
            max_logprobs=num_top_logprobs,
            enable_chunked_prefill=enable_chunked_prefill,
            max_num_batched_tokens=max_num_batched_tokens,
            max_num_seqs=max_num_seqs,
            enforce_eager=True,
    ) as vllm_model:
        vllm_sampling_params = SamplingParams(max_tokens=max_tokens,
                                              logprobs=num_top_logprobs,
@ -140,7 +122,9 @@ def test_get_prompt_logprobs(


def test_max_logprobs():
-    runner = VllmRunner("facebook/opt-125m", max_logprobs=1)
+    runner = VllmRunner("facebook/opt-125m",
+                        max_logprobs=1,
+                        enforce_eager=True)
    vllm_sampling_params = SamplingParams(logprobs=1)
    # should pass
    runner.generate(["Hello world"], sampling_params=vllm_sampling_params)
@ -151,24 +135,16 @@ def test_max_logprobs():


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("chunked_prefill_token_size", [1, 4, 16, -1])
@pytest.mark.parametrize("max_num_batched_tokens", [4])
@pytest.mark.parametrize("detokenize", [True, False])
-def test_none_logprobs(vllm_runner, model, chunked_prefill_token_size: int,
+def test_none_logprobs(vllm_runner, model, max_num_batched_tokens: int,
                       detokenize: bool, example_prompts):
    max_num_seqs = 256
    enable_chunked_prefill = False
    max_num_batched_tokens = None
    if chunked_prefill_token_size != -1:
        enable_chunked_prefill = True
        max_num_seqs = min(chunked_prefill_token_size, max_num_seqs)
        max_num_batched_tokens = chunked_prefill_token_size
    max_tokens = 5

    with vllm_runner(
            model,
            enable_chunked_prefill=enable_chunked_prefill,
            enforce_eager=True,
            max_num_batched_tokens=max_num_batched_tokens,
            max_num_seqs=max_num_seqs,
    ) as vllm_model:
        sampling_params_logprobs_none = SamplingParams(max_tokens=max_tokens,
                                                       logprobs=None,

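The hunks above exercise SamplingParams.logprobs and prompt_logprobs. A compact sketch of how those fields are requested and read back through the offline API, assuming the usual RequestOutput layout; model and prompt are illustrative:

    from vllm import LLM, SamplingParams

    llm = LLM(model="distilbert/distilgpt2", max_logprobs=5, enforce_eager=True)
    params = SamplingParams(max_tokens=5, logprobs=5, prompt_logprobs=5)

    out = llm.generate(["vLLM is a"], params)[0]
    # Per-position dicts mapping token_id -> Logprob, for the prompt and the output.
    print(out.prompt_logprobs[1])
    print(out.outputs[0].logprobs[0])
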
@ -7,18 +7,11 @@ Run `pytest tests/samplers/test_no_bad_words.py`.
"""
from typing import Optional

import pytest
from transformers import AutoTokenizer

from vllm import LLM, SamplingParams


-@pytest.fixture(autouse=True)
-def v1(monkeypatch):
-    """Only run on vLLM v1."""
-    monkeypatch.setenv('VLLM_USE_V1', '1')


def _generate(
    llm: LLM,
    prompt: str,
@ -57,7 +50,7 @@ class TestOneTokenBadWord:
                                             add_special_tokens=False)[0]

    def test_one_token_bad_word(self, vllm_runner):
-        with vllm_runner(self.MODEL) as llm:
+        with vllm_runner(self.MODEL, enforce_eager=True) as llm:
            output_token_ids = self._generate(llm)
            assert output_token_ids[0] == self.target_token_id

@ -104,7 +97,7 @@ class TestTwoTokenBadWord:
                                             add_special_tokens=False)[0]

    def test_two_token_bad_word(self, vllm_runner):
-        with vllm_runner(self.MODEL, dtype="half") as llm:
+        with vllm_runner(self.MODEL, enforce_eager=True) as llm:
            output_token_ids = self._generate(llm)
            assert output_token_ids[:2] == [
                self.target_token_id1, self.target_token_id2

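These tests ban specific words from being generated. The public knob they exercise is the bad_words field of SamplingParams; a hedged sketch of the same idea (model, prompt, and banned word are illustrative, and the field name is assumed from this test module rather than shown in the hunks above):

    from vllm import LLM, SamplingParams

    llm = LLM(model="distilbert/distilgpt2", enforce_eager=True)
    params = SamplingParams(max_tokens=8, bad_words=["Paris"])

    out = llm.generate(["The capital of France is"], params)[0]
    # The banned word should not appear in the generated text.
    assert "Paris" not in out.outputs[0].text
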
@ -8,12 +8,6 @@ from vllm import SamplingParams
MODELS = ["distilbert/distilgpt2"]


-@pytest.fixture(autouse=True)
-def v1(run_with_both_engines):
-    """We can run both engines for this test."""
-    pass


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"])
def test_ranks(
@ -26,7 +20,9 @@ def test_ranks(
    num_top_logprobs = 5
    num_prompt_logprobs = 5

-    with vllm_runner(model, dtype=dtype,
+    with vllm_runner(model,
+                     dtype=dtype,
+                     enforce_eager=True,
                     max_logprobs=num_top_logprobs) as vllm_model:

        ## Test greedy logprobs ranks

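test_ranks checks the rank reported alongside each logprob. A short sketch of that property through the public API, assuming the Logprob entries carry a rank field as these tests rely on: under greedy decoding every sampled token should be reported with rank 1.

    from vllm import LLM, SamplingParams

    llm = LLM(model="distilbert/distilgpt2", enforce_eager=True, max_logprobs=5)
    params = SamplingParams(temperature=0.0, max_tokens=4, logprobs=5)

    completion = llm.generate(["Hello, my name is"], params)[0].outputs[0]
    for token_id, logprob_dict in zip(completion.token_ids, completion.logprobs):
        # Greedy decoding picks the most likely token, so its rank should be 1.
        assert logprob_dict[token_id].rank == 1
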
@ -1,769 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import itertools
import random
from dataclasses import dataclass
from typing import Optional
from unittest.mock import Mock, patch

import pytest
import torch
from transformers import GenerationConfig, GenerationMixin

import vllm.envs as envs
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import set_random_seed
from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
from vllm.utils import Counter, is_pin_memory_available


@pytest.fixture(scope="function", autouse=True)
def use_v0_only(monkeypatch):
    """
    This file tests V0 internals, so set VLLM_USE_V1=0.
    """
    monkeypatch.setenv('VLLM_USE_V1', '0')


class MockLogitsSampler(Sampler):

    def __init__(self, fake_logits: torch.Tensor):
        super().__init__()
        self.fake_logits = fake_logits

    def forward(self, *args, **kwargs):
        return super().forward(*args, **kwargs)


def _prepare_test(
        batch_size: int
) -> tuple[torch.Tensor, torch.Tensor, MockLogitsSampler]:
    input_tensor = torch.rand((batch_size, 1024), dtype=torch.float16)
    fake_logits = torch.full((batch_size, VOCAB_SIZE),
                             1e-2,
                             dtype=input_tensor.dtype)
    sampler = MockLogitsSampler(fake_logits)
    return input_tensor, fake_logits, sampler


VOCAB_SIZE = 32000
RANDOM_SEEDS = list(range(128))
CUDA_DEVICES = [
    f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]


def _do_sample(
    batch_size: int,
    input_tensor: torch.Tensor,
    sampler: MockLogitsSampler,
    sampling_params: SamplingParams,
    device: str,
):
    seq_group_metadata_list: list[SequenceGroupMetadata] = []
    seq_lens: list[int] = []
    for i in range(batch_size):
        seq_group_metadata_list.append(
            SequenceGroupMetadata(
                request_id=f"test_{i}",
                is_prompt=True,
                seq_data={0: SequenceData.from_seqs([1, 2, 3])},
                sampling_params=sampling_params,
                block_tables={0: [1]},
            ))
        seq_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())

    sampling_metadata = SamplingMetadata.prepare(
        seq_group_metadata_list,
        seq_lens,
        query_lens=seq_lens,
        device=device,
        pin_memory=is_pin_memory_available())
    return sampler(logits=input_tensor, sampling_metadata=sampling_metadata)


@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_all_greedy(seed: int, device: str):
    set_random_seed(seed)
    torch.set_default_device(device)
    batch_size = random.randint(1, 256)
    input_tensor, fake_logits, sampler = _prepare_test(batch_size)

    sampling_params = SamplingParams(temperature=0)
    sampler_output = _do_sample(batch_size, fake_logits, sampler,
                                sampling_params, device)
    expected = torch.argmax(fake_logits, dim=-1)
    for i, sequence_output in enumerate(sampler_output):
        for nth_output in sequence_output.samples:
            assert nth_output.output_token == expected[i].item()


@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_all_random(seed: int, device: str):
    set_random_seed(seed)
    torch.set_default_device(device)
    batch_size = random.randint(1, 256)
    _, fake_logits, sampler = _prepare_test(batch_size)

    for i in range(batch_size):
        fake_logits[i, i] = 1e2

    sampling_params = SamplingParams(
        temperature=1.0,
        n=random.randint(1, 10),
    )
    sampler_output = _do_sample(batch_size, fake_logits, sampler,
                                sampling_params, device)

    for i, sequence_output in enumerate(sampler_output):
        for nth_output in sequence_output.samples:
            assert nth_output.output_token == i


@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_all_random_seed(seed: int, device: str):
    set_random_seed(seed)
    torch.set_default_device(device)
    batch_size = random.randint(1, 256)
    _, fake_logits, sampler = _prepare_test(batch_size)

    for i in range(batch_size):
        fake_logits[i, i] = 1e2

    sampling_params = SamplingParams(
        temperature=1.0,
        n=random.randint(1, 10),
        seed=random.randint(0, 10000),
    )
    sampler_output = _do_sample(batch_size, fake_logits, sampler,
                                sampling_params, device)

    for i, sequence_output in enumerate(sampler_output):
        for nth_output in sequence_output.samples:
            assert nth_output.output_token == i


@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_all_random_seed_deterministic(seed: int, device: str):
    set_random_seed(seed)
    torch.set_default_device(device)
    batch_size = random.randint(1, 256)
    _, fake_logits, sampler = _prepare_test(batch_size)

    sampling_params = SamplingParams(
        temperature=1.0,
        n=random.randint(1, 10),
        seed=random.randint(0, 10000),
    )
    first_sampler_output = _do_sample(batch_size, fake_logits, sampler,
                                      sampling_params, device)

    second_sampler_output = _do_sample(batch_size, fake_logits, sampler,
                                       sampling_params, device)

    assert first_sampler_output == second_sampler_output


@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_min_tokens_penalty(seed: int, device: str):
    seq_id_counter = Counter(start=random.randint(0, 100))
    set_random_seed(seed)
    torch.set_default_device(device)

    def create_sampling_params(min_tokens,
                               eos_token_id=0,
                               *,
                               stop_token_ids: Optional[list[int]] = None,
                               prompt_logprobs: Optional[int] = None):
        sampling_params = SamplingParams(
            min_tokens=min_tokens,
            max_tokens=9999,  # keep higher than max of min_tokens
            stop_token_ids=stop_token_ids,
            # requesting prompt_logprobs changes the structure of `logits`
            prompt_logprobs=prompt_logprobs,
        )
        sampling_params.all_stop_token_ids.add(eos_token_id)
        return sampling_params

    def create_sequence_data(num_input=3, num_generated=0):
        seq_data = SequenceData.from_seqs(
            random.choices(range(0, VOCAB_SIZE), k=num_input))
        if num_generated > 0:
            seq_data.output_token_ids = random.choices(range(0, VOCAB_SIZE),
                                                       k=num_generated)
        return seq_data

    def generate_test_case():
        # generate multiple seq groups but limit total batch size
        batch_size = random.randint(1, 128)

        expected_penalization = []
        sequence_metadata_list: list[SequenceGroupMetadata] = []
        # 20% chance to generate seq group metadata list with all prompts
        is_prompt = random.random() < 0.2
        while batch_size > 0:
            num_seqs = 1 if is_prompt else random.randint(1, batch_size)

            eos_token_id = random.randint(0, VOCAB_SIZE - 1)
            min_tokens = random.randint(0, 50)
            num_stop_tokens = random.randint(0, 8)
            if num_stop_tokens > 0:
                stop_token_ids = random.choices(range(0, VOCAB_SIZE - 1),
                                                k=num_stop_tokens)
            else:
                stop_token_ids = None

            sampling_params = create_sampling_params(
                min_tokens=min_tokens,
                eos_token_id=eos_token_id,
                stop_token_ids=stop_token_ids)

            seq_data: dict[int, SequenceData] = {}
            seq_group_penalization: list[bool] = []
            for _ in range(num_seqs):
                num_input = random.randint(1, 100)
                num_generated = 0 if is_prompt else random.randint(1, 100)
                seq_data[next(seq_id_counter)] = create_sequence_data(
                    num_input=num_input, num_generated=num_generated)
                seq_group_penalization.append(num_generated < min_tokens)

            expected_penalization.extend(seq_group_penalization)
            sequence_metadata_list.append(
                SequenceGroupMetadata(
                    request_id=f"test_{batch_size}",
                    is_prompt=is_prompt,
                    seq_data=seq_data,
                    sampling_params=sampling_params,
                    block_tables={},
                ))
            batch_size -= num_seqs

        return {
            "expected_penalization": expected_penalization,
            "seq_group_metadata_list": sequence_metadata_list,
        }

    # define some explicit test cases for edge case behavior
    prompt_without_penalization = {
        "expected_penalization": [False],
        "seq_group_metadata_list": [
            SequenceGroupMetadata(
                request_id="test_1",
                is_prompt=True,
                seq_data={
                    next(seq_id_counter): create_sequence_data(),
                },
                sampling_params=create_sampling_params(0),
                block_tables={},
            ),
        ]
    }

    prompt_with_penalization = {
        "expected_penalization": [True],
        "seq_group_metadata_list": [
            SequenceGroupMetadata(
                request_id="test_1",
                is_prompt=True,
                seq_data={
                    next(seq_id_counter): create_sequence_data(),
                },
                sampling_params=create_sampling_params(1),
                block_tables={},
            ),
        ]
    }

    prompt_with_penalization_and_prompt_logprobs = {
        "expected_penalization": [False, False, True],
        "seq_group_metadata_list": [
            SequenceGroupMetadata(
                request_id="test_1",
                is_prompt=True,
                seq_data={
                    next(seq_id_counter): create_sequence_data(num_input=3),
                },
                sampling_params=create_sampling_params(1, prompt_logprobs=3),
                block_tables={},
            ),
        ]
    }

    stop_penalizing_after_min_tokens = {
        "expected_penalization": [False],
        "seq_group_metadata_list": [
            SequenceGroupMetadata(
                request_id="test_1",
                is_prompt=False,
                seq_data={
                    next(seq_id_counter):
                    create_sequence_data(num_generated=1),
                },
                sampling_params=create_sampling_params(1),
                block_tables={},
            )
        ]
    }

    stop_token_ids = [42, 99, 42, 0]  # intentional duplication
    prompt_combination = {
        "expected_penalization": [False, True, False],
        "seq_group_metadata_list": [
            SequenceGroupMetadata(
                request_id="test_2",
                is_prompt=True,
                seq_data={
                    next(seq_id_counter): create_sequence_data(num_input=2),
                },
                sampling_params=create_sampling_params(1, prompt_logprobs=3),
                block_tables={},
            ),
            SequenceGroupMetadata(
                request_id="test_3",
                is_prompt=True,
                seq_data={
                    next(seq_id_counter): create_sequence_data(),
                },
                sampling_params=create_sampling_params(
                    0, stop_token_ids=stop_token_ids),
                block_tables={},
            )
        ]
    }

    stop_token_ids = [1, 999, 37, 37]  # intentional duplication
    decode_combination = {
        "expected_penalization": [True, False, False, True, False],
        "seq_group_metadata_list": [
            SequenceGroupMetadata(
                request_id="test_1",
                is_prompt=False,
                seq_data={
                    next(seq_id_counter):
                    create_sequence_data(num_generated=1),
                    next(seq_id_counter):
                    create_sequence_data(num_generated=100),
                },
                sampling_params=create_sampling_params(
                    2, stop_token_ids=stop_token_ids),
                block_tables={},
            ),
            SequenceGroupMetadata(
                request_id="test_2",
                is_prompt=False,
                seq_data={
                    next(seq_id_counter):
                    create_sequence_data(num_generated=20),
                    next(seq_id_counter):
                    create_sequence_data(num_generated=1),
                    next(seq_id_counter):
                    create_sequence_data(num_generated=10),
                },
                sampling_params=create_sampling_params(
                    10, prompt_logprobs=5, stop_token_ids=stop_token_ids),
                block_tables={},
            ),
        ]
    }

    if seed == 0:
        test_cases = [
            prompt_without_penalization,
            prompt_with_penalization,
            prompt_with_penalization_and_prompt_logprobs,
            stop_penalizing_after_min_tokens,
            prompt_combination,
            decode_combination,
        ]
    else:
        test_cases = [generate_test_case()]

    def run_test_case(*, expected_penalization: list[bool],
                      seq_group_metadata_list: list[SequenceGroupMetadata]):
        assert expected_penalization, \
            "Invalid test case, need expected_penalization"
        assert seq_group_metadata_list, \
            "Invalid test case, need seq_group_metadata_list"

        batch_size = 0
        seq_lens: list[int] = []
        sampling_params_per_row: list[SamplingParams] = []
        for sgm in seq_group_metadata_list:
            sampling_params = sgm.sampling_params

            num_rows = len(sgm.seq_data)
            if sgm.is_prompt:
                # a prompt seq_group has only one sequence
                seq_data = next(iter(sgm.seq_data.values()))
                prompt_len = seq_data.get_prompt_len()
                seq_lens.append(prompt_len)

                assert sgm.sampling_params is not None
                if sgm.sampling_params.prompt_logprobs:
                    # with prompt_logprobs each token in the prompt has a row in
                    # logits
                    num_rows = prompt_len

            batch_size += num_rows
            sampling_params_per_row.extend(
                itertools.repeat(sampling_params, num_rows))

        assert len(
            expected_penalization
        ) == batch_size, \
            ("Invalid test case, expected_penalization does not match computed"
             "batch size")

        _, fake_logits, sampler = _prepare_test(batch_size)
        sampling_metadata = SamplingMetadata.prepare(
            seq_group_metadata_list,
            seq_lens=seq_lens if seq_lens else None,
            query_lens=seq_lens if seq_lens else [1] * batch_size,
            device=device,
            pin_memory=is_pin_memory_available())
        # the logits tensor is modified in-place by the sampler
        _ = sampler(logits=fake_logits, sampling_metadata=sampling_metadata)

        for logits_idx, (should_penalize, sampling_params) in enumerate(
                zip(expected_penalization, sampling_params_per_row)):

            tokens_to_check = sampling_params.all_stop_token_ids

            if should_penalize:
                for token_id in tokens_to_check:
                    assert fake_logits[logits_idx, token_id] == -float(
                        'inf'
                    ), f"Expected token {token_id} for logits row {logits_idx}"
                    " to be penalized"
                # no other tokens should be set to -inf
                assert torch.count_nonzero(
                    fake_logits[logits_idx, :] == -float('inf')) == len(
                        tokens_to_check
                    ), f"Expected only {len(tokens_to_check)} to be penalized"
            else:
                # no tokens should be set to -inf
                assert torch.count_nonzero(
                    fake_logits[logits_idx, :] ==
                    -float('inf')) == 0, "No tokens should have been penalized"

    for test_case in test_cases:
        run_test_case(**test_case)


@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_mixed(seed: int, device: str):
    set_random_seed(seed)
    torch.set_default_device(device)
    batch_size = random.randint(1, 256)
    input_tensor, fake_logits, sampler = _prepare_test(batch_size)

    seq_group_metadata_list: list[SequenceGroupMetadata] = []
    expected_tokens: list[Optional[list[int]]] = []
    seq_lens: list[int] = []
    for i in range(batch_size):
        expected: Optional[list[int]] = None
        sampling_type = random.randint(0, 2)
        if sampling_type == 0:
            sampling_params = SamplingParams(temperature=0)
            expected = [int(torch.argmax(fake_logits[i], dim=-1).item())]
        elif sampling_type in (1, 2):
            n = random.randint(1, 10)
            sampling_params = SamplingParams(
                temperature=random.random() + 0.1,
                top_p=min(random.random() + 0.1, 1),
                top_k=random.randint(0, 10),
                n=n,
                presence_penalty=random.randint(0, 1),
            )
            if sampling_type == 2:
                sampling_params.seed = random.randint(0, 10000)
            else:
                for idx in range(n):
                    fake_logits[i, i + idx] = 1e2
                expected = list(range(i, i + n))

        expected_tokens.append(expected)
        seq_group_metadata_list.append(
            SequenceGroupMetadata(
                request_id=f"test_{i}",
                is_prompt=True,
                seq_data={0: SequenceData.from_seqs([1, 2, 3])},
                sampling_params=sampling_params,
                block_tables={0: [1]},
            ))
        seq_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())

    generators: dict[str, torch.Generator] = {}

    def test_sampling():
        sampling_metadata = SamplingMetadata.prepare(
            seq_group_metadata_list,
            seq_lens,
            query_lens=seq_lens,
            device=device,
            pin_memory=is_pin_memory_available(),
            generators=generators)
        sampler_output = sampler(logits=fake_logits,
                                 sampling_metadata=sampling_metadata)

        for i, (sequence_output, metadata) in enumerate(
                zip(sampler_output, seq_group_metadata_list)):
            assert metadata.sampling_params is not None

            if (metadata.sampling_params.seed is not None
                    and expected_tokens[i] is None):
                # Record seeded random result to compare with results of
                # second invocation
                expected_tokens[i] = [
                    nth_output.output_token
                    for nth_output in sequence_output.samples
                ]
                continue

            expected_tokens_item = expected_tokens[i]
            assert expected_tokens_item is not None

            for n, nth_output in enumerate(sequence_output.samples):
                assert metadata.sampling_params is not None

                if (metadata.sampling_params.temperature == 0
                        or metadata.sampling_params.seed is not None):
                    # Ensure exact matches for greedy or random with seed
                    assert nth_output.output_token == expected_tokens_item[n]
                else:
                    # For non-seeded random check that one of the high-logit
                    # tokens were chosen
                    assert nth_output.output_token in expected_tokens_item

    # Test batch
    test_sampling()

    # Shuffle the batch and resample
    target_index = list(range(batch_size))
    for list_to_shuffle in (target_index, seq_group_metadata_list,
                            expected_tokens, seq_lens):
        random.Random(seed).shuffle(list_to_shuffle)
    target_index = torch.tensor(target_index)
    input_tensor.data = input_tensor.index_select(0, target_index)
    fake_logits.data = fake_logits.index_select(0, target_index)

    # This time, results of seeded random samples will be compared with
    # the corresponding sample in the pre-shuffled batch
    test_sampling()


@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_top_k_top_p(seed: int, device: str):
    set_random_seed(seed)
    batch_size = random.randint(1, 256)
    top_k = random.randint(100, 500)
    top_p = random.random() * 0.1
    vocab_size = 32000
    input_tensor = torch.rand((batch_size, 1024),
                              device=device,
                              dtype=torch.float16)
    fake_logits = torch.normal(0,
                               5,
                               size=(batch_size, vocab_size),
                               device=input_tensor.device,
                               dtype=input_tensor.dtype)
    sampler = MockLogitsSampler(fake_logits)

    generation_model = GenerationMixin()
    generation_config = GenerationConfig(top_k=top_k,
                                         top_p=top_p,
                                         do_sample=True)

    @dataclass
    class MockConfig:
        is_encoder_decoder: bool = False

    generation_model.config = MockConfig()  # needed by the following method
    generation_model._prepare_special_tokens(generation_config, device=device)
    processors = generation_model._get_logits_processor(generation_config,
                                                        None,
                                                        None,
                                                        None, [],
                                                        device=device)
    assert len(processors) == 2  # top_p and top_k

    seq_group_metadata_list: list[SequenceGroupMetadata] = []
    seq_lens: list[int] = []
    for i in range(batch_size):
        seq_group_metadata_list.append(
            SequenceGroupMetadata(
                request_id=f"test_{i}",
                is_prompt=True,
                seq_data={0: SequenceData.from_seqs([1, 2, 3])},
                sampling_params=SamplingParams(
                    temperature=1,
                    top_k=top_k,
                    top_p=top_p,
                ),
                block_tables={0: [1]},
            ))
        seq_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())

    sampling_metadata = SamplingMetadata.prepare(
        seq_group_metadata_list,
        seq_lens,
        query_lens=seq_lens,
        device=device,
        pin_memory=is_pin_memory_available())

    sample_probs = None

    def mock_sample(probs, *args, **kwargs):
        nonlocal sample_probs
        sample_probs = probs
        return ([[prob.topk(1, dim=-1).indices.tolist(), [0]]
                 for prob in probs], None)

    # top-k and top-p is only calculated when flashinfer kernel is not available
    with patch("vllm.model_executor.layers.sampler._sample", mock_sample), \
        patch("vllm.model_executor.layers.sampler."
              "flashinfer_top_k_top_p_sampling", None):
        sampler(logits=fake_logits, sampling_metadata=sampling_metadata)

    assert sample_probs is not None

    hf_probs = processors(torch.zeros_like(fake_logits), fake_logits.clone())
    hf_probs = torch.softmax(hf_probs, dim=-1, dtype=torch.float)
    torch.testing.assert_close(hf_probs, sample_probs, rtol=0.0, atol=1e-5)
    assert torch.equal(hf_probs.eq(0), sample_probs.eq(0))


@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_flashinfer_fallback(seed: int, device: str):
    if not envs.VLLM_USE_FLASHINFER_SAMPLER:
        pytest.skip("Flashinfer sampler is disabled")

    pytest.skip("After FlashInfer 0.2.3, sampling will never fail")

    set_random_seed(seed)
    torch.set_default_device(device)
    batch_size = random.randint(1, 256)
    _, fake_logits, sampler = _prepare_test(batch_size)

    def failing_flashinfer_sampling(*_args, **_kwargs):
        return None, torch.zeros(batch_size, device=device, dtype=torch.int32)

    sampling_params = SamplingParams(
        temperature=1.0,
        n=random.randint(1, 10),
        seed=random.randint(0, 10000),
    )
    sampler_output = _do_sample(batch_size, fake_logits, sampler,
                                sampling_params, device)

    with patch(
            "vllm.model_executor.layers.sampler."
            "flashinfer_top_k_top_p_sampling", failing_flashinfer_sampling):
        fallback_sampler_output = _do_sample(batch_size, fake_logits, sampler,
                                             sampling_params, device)

    assert sampler_output == fallback_sampler_output


@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_repetition_penalty_mixed(device: str):

    vocab_size = 8

    def test_sampling_params(sampling_params: list[SamplingParams]):

        seq_group_metadata_list: list[SequenceGroupMetadata] = []
        seq_lens: list[int] = []
        for i in range(2):
            seq_group_metadata_list.append(
                SequenceGroupMetadata(
                    request_id=f"test_{i}",
                    is_prompt=True,
                    seq_data={0: SequenceData.from_seqs([1, 2, 3])},
                    sampling_params=sampling_params[i],
                    block_tables={0: [1]},
                ))
            seq_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())

        sampling_metadata = SamplingMetadata.prepare(
            seq_group_metadata_list,
            seq_lens,
            query_lens=seq_lens,
            device=device,
            pin_memory=is_pin_memory_available())

        fake_logits = torch.full((2, vocab_size),
                                 1e-2,
                                 device=device,
                                 dtype=torch.float16)

        fake_logits[:, 5] = 1.1e-2
        fake_logits[:, 1] = 1.2e-2

        sampler = MockLogitsSampler(fake_logits)

        sampler_output = sampler(logits=fake_logits,
                                 sampling_metadata=sampling_metadata)

        generated_tokens = []
        for output in sampler_output:
            generated_tokens.append(output.samples[0].output_token)

        return generated_tokens

    # one configuration is greedy with repetition_penalty
    sampling_params_rep = SamplingParams(
        temperature=0.0,
        repetition_penalty=2.0,
    )

    # other configuration is sampling w/o repetition_penalty
    sampling_params_sample = SamplingParams(
        temperature=1.0,
        top_k=1,
        seed=42,
    )

    tokens1 = test_sampling_params(
        [sampling_params_rep, sampling_params_sample])

    tokens2 = test_sampling_params(
        [sampling_params_sample, sampling_params_rep])

    assert tokens1[0] == tokens2[1]
    assert tokens1[1] == tokens2[0]


@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_include_gpu_probs_tensor(device: str):
    set_random_seed(42)
    torch.set_default_device(device)
    batch_size = random.randint(1, 256)
    _, fake_logits, sampler = _prepare_test(batch_size)
    sampler.include_gpu_probs_tensor = True
    sampler.should_modify_greedy_probs_inplace = False

    sampling_params = SamplingParams(temperature=0)

    mock_inplace = Mock()
    with patch(
            "vllm.model_executor.layers.sampler._modify_greedy_probs_inplace",
            mock_inplace):

        sampler_output = _do_sample(batch_size, fake_logits, sampler,
                                    sampling_params, device)
        mock_inplace.assert_not_called()

    assert sampler_output.sampled_token_probs is not None
    assert sampler_output.logprobs is not None
    assert sampler_output.sampled_token_ids is not None
@ -1,86 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Verify that seeded random sampling is deterministic.

Run `pytest tests/samplers/test_seeded_generate.py`.
"""
import copy
import random
from itertools import combinations

import pytest

from vllm import SamplingParams
from vllm.model_executor.utils import set_random_seed

MODEL = "facebook/opt-125m"
RANDOM_SEEDS = list(range(5))


@pytest.fixture
def vllm_model(vllm_runner, monkeypatch):
    # This file relies on V0 internals.
    monkeypatch.setenv("VLLM_USE_V1", "0")
    with vllm_runner(MODEL, dtype="half") as vllm_model:
        yield vllm_model


@pytest.mark.parametrize("seed", RANDOM_SEEDS)
def test_random_sample_with_seed(
    vllm_model,
    example_prompts,
    seed: int,
) -> None:
    set_random_seed(seed)

    sampling_params = SamplingParams(
        # Parameters to ensure sufficient randomness
        temperature=3.0,
        top_p=min(random.random() + 0.3, 1),
        top_k=random.randint(5, 20),
        n=random.randint(1, 10),
        presence_penalty=random.randint(0, 1),
        max_tokens=8,
        ignore_eos=True,
    )

    sampling_params_seed_1 = copy.deepcopy(sampling_params)
    sampling_params_seed_1.seed = 100
    sampling_params_seed_2 = copy.deepcopy(sampling_params)
    sampling_params_seed_2.seed = 200

    llm = vllm_model.llm

    for prompt in example_prompts:
        for params in (
                sampling_params,
                sampling_params_seed_1,
                sampling_params_seed_2,
                sampling_params,
                sampling_params_seed_1,
                sampling_params_seed_2,
        ):
            llm._add_request(prompt, params=params)

    results = llm._run_engine(use_tqdm=False)
    all_outputs = [[out.token_ids for out in output.outputs]
                   for output in results]

    for i in range(0, len(example_prompts), 6):
        outputs = all_outputs[i:i + 6]

        # verify all non-seeded requests differ
        for output_a, output_b in combinations(
            (outputs[0], outputs[1], outputs[2], outputs[3]),
                2,
        ):
            assert output_a != output_b

        # verify requests with the same seed match
        assert outputs[1] == outputs[4]
        assert outputs[2] == outputs[5]

        # verify generations within the same parallel sampling group differ
        for output in outputs:
            for sub_output_a, sub_output_b in combinations(output, 2):
                assert sub_output_a != sub_output_b

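The deleted test above pins down determinism of seeded sampling through V0 internals (_add_request/_run_engine). The same property can be checked against the public offline API; a hedged sketch (model and prompt are illustrative):

    from vllm import LLM, SamplingParams

    llm = LLM(model="facebook/opt-125m", enforce_eager=True)
    params = SamplingParams(temperature=1.0, top_p=0.9, max_tokens=8, seed=100)

    first = llm.generate(["The future of AI is"], params)[0].outputs[0].token_ids
    second = llm.generate(["The future of AI is"], params)[0].outputs[0].token_ids
    # The same seed should reproduce the same sampled tokens.
    assert first == second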