mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 14:53:52 +08:00
feat: implement the min_tokens sampling parameter (#3124)
Signed-off-by: Travis Johnson <tsjohnso@us.ibm.com> Co-authored-by: Nick Hill <nickhill@us.ibm.com>
This commit is contained in:
@ -10,6 +10,7 @@ from vllm.model_executor.layers.sampler import Sampler
|
||||
from vllm.model_executor.utils import set_random_seed
|
||||
from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
|
||||
from vllm.worker.model_runner import ModelRunner
|
||||
from vllm.utils import Counter
|
||||
|
||||
|
||||
class MockLogitsSampler(Sampler):
|
||||
@ -25,9 +26,8 @@ class MockLogitsSampler(Sampler):
|
||||
def _prepare_test(
|
||||
batch_size: int
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, MockLogitsSampler, ModelRunner]:
|
||||
vocab_size = 32000
|
||||
input_tensor = torch.rand((batch_size, 1024), dtype=torch.float16)
|
||||
fake_logits = torch.full((batch_size, vocab_size),
|
||||
fake_logits = torch.full((batch_size, VOCAB_SIZE),
|
||||
1e-2,
|
||||
dtype=input_tensor.dtype)
|
||||
sampler = MockLogitsSampler(fake_logits)
|
||||
@ -35,6 +35,7 @@ def _prepare_test(
|
||||
return input_tensor, fake_logits, sampler, model_runner
|
||||
|
||||
|
||||
VOCAB_SIZE = 32000
|
||||
RANDOM_SEEDS = list(range(128))
|
||||
CUDA_DEVICES = [
|
||||
f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
|
||||
@ -184,6 +185,225 @@ def test_sampler_all_beam(seed: int, device: str):
|
||||
del model_runner
|
||||
|
||||
|
||||
@pytest.mark.parametrize("seed", RANDOM_SEEDS)
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
def test_sampler_min_tokens_penalty(seed: int, device: str):
|
||||
seq_id_counter = Counter(start=random.randint(0, 100))
|
||||
set_random_seed(seed)
|
||||
torch.set_default_device(device)
|
||||
|
||||
def create_sampling_params(min_tokens,
|
||||
eos_token_id=0,
|
||||
stop_token_ids=None):
|
||||
sampling_params = SamplingParams(
|
||||
min_tokens=min_tokens,
|
||||
max_tokens=9999, # keep higher than max of min_tokens
|
||||
stop_token_ids=stop_token_ids,
|
||||
)
|
||||
sampling_params.eos_token_id = eos_token_id
|
||||
return sampling_params
|
||||
|
||||
def create_sequence_data(num_input=3, num_generated=0):
|
||||
seq_data = SequenceData(
|
||||
random.choices(range(0, VOCAB_SIZE), k=num_input))
|
||||
if num_generated > 0:
|
||||
seq_data.output_token_ids = random.choices(range(0, VOCAB_SIZE),
|
||||
k=num_generated)
|
||||
return seq_data
|
||||
|
||||
def generate_test_case():
|
||||
# generate multiple seq groups but limit total batch size
|
||||
batch_size = random.randint(1, 128)
|
||||
|
||||
expected_penalization = []
|
||||
sequence_metadata_list = []
|
||||
while batch_size > 0:
|
||||
# 20% chance to generate prompt seq group with single sequence
|
||||
is_prompt = random.random() < 0.2
|
||||
num_seqs = 1 if is_prompt else random.randint(1, batch_size)
|
||||
|
||||
eos_token_id = random.randint(0, VOCAB_SIZE - 1)
|
||||
min_tokens = random.randint(0, 50)
|
||||
num_stop_tokens = random.randint(0, 8)
|
||||
if num_stop_tokens > 0:
|
||||
stop_token_ids = random.choices(range(0, VOCAB_SIZE - 1),
|
||||
k=num_stop_tokens)
|
||||
else:
|
||||
stop_token_ids = None
|
||||
|
||||
sampling_params = create_sampling_params(
|
||||
min_tokens=min_tokens,
|
||||
eos_token_id=eos_token_id,
|
||||
stop_token_ids=stop_token_ids)
|
||||
|
||||
seq_data = {}
|
||||
seq_group_penalization = []
|
||||
for _ in range(num_seqs):
|
||||
num_input = random.randint(1, 100)
|
||||
num_generated = random.randint(1, 100) if not is_prompt else 0
|
||||
seq_data[next(seq_id_counter)] = create_sequence_data(
|
||||
num_input=num_input, num_generated=num_generated)
|
||||
seq_group_penalization.append(num_generated < min_tokens)
|
||||
|
||||
expected_penalization.extend(seq_group_penalization)
|
||||
sequence_metadata_list.append(
|
||||
SequenceGroupMetadata(
|
||||
request_id=f"test_{batch_size}",
|
||||
is_prompt=is_prompt,
|
||||
seq_data=seq_data,
|
||||
sampling_params=sampling_params,
|
||||
block_tables={},
|
||||
))
|
||||
batch_size -= num_seqs
|
||||
|
||||
return {
|
||||
"expected_penalization": expected_penalization,
|
||||
"seq_group_metadata_list": sequence_metadata_list,
|
||||
}
|
||||
|
||||
# define some explicit test cases for edge case behavior
|
||||
prompt_without_penalization = {
|
||||
"expected_penalization": [False],
|
||||
"seq_group_metadata_list": [
|
||||
SequenceGroupMetadata(
|
||||
request_id="test_1",
|
||||
is_prompt=True,
|
||||
seq_data={
|
||||
next(seq_id_counter): create_sequence_data(),
|
||||
},
|
||||
sampling_params=create_sampling_params(0),
|
||||
block_tables={},
|
||||
),
|
||||
]
|
||||
}
|
||||
|
||||
prompt_with_penalization = {
|
||||
"expected_penalization": [True],
|
||||
"seq_group_metadata_list": [
|
||||
SequenceGroupMetadata(
|
||||
request_id="test_1",
|
||||
is_prompt=True,
|
||||
seq_data={
|
||||
next(seq_id_counter): create_sequence_data(),
|
||||
},
|
||||
sampling_params=create_sampling_params(1),
|
||||
block_tables={},
|
||||
),
|
||||
]
|
||||
}
|
||||
|
||||
stop_penalizing_after_min_tokens = {
|
||||
"expected_penalization": [False],
|
||||
"seq_group_metadata_list": [
|
||||
SequenceGroupMetadata(
|
||||
request_id="test_1",
|
||||
is_prompt=False,
|
||||
seq_data={
|
||||
next(seq_id_counter):
|
||||
create_sequence_data(num_generated=1),
|
||||
},
|
||||
sampling_params=create_sampling_params(1),
|
||||
block_tables={},
|
||||
)
|
||||
]
|
||||
}
|
||||
|
||||
stop_token_ids = [42, 99, 42, 0] # intentional duplication
|
||||
simple_combination = {
|
||||
"expected_penalization": [True, False, False],
|
||||
"seq_group_metadata_list": [
|
||||
SequenceGroupMetadata(
|
||||
request_id="test_1",
|
||||
is_prompt=False,
|
||||
seq_data={
|
||||
next(seq_id_counter):
|
||||
create_sequence_data(num_generated=1),
|
||||
next(seq_id_counter):
|
||||
create_sequence_data(num_generated=100),
|
||||
},
|
||||
sampling_params=create_sampling_params(
|
||||
2, stop_token_ids=stop_token_ids),
|
||||
block_tables={},
|
||||
),
|
||||
SequenceGroupMetadata(
|
||||
request_id="test_2",
|
||||
is_prompt=True,
|
||||
seq_data={
|
||||
next(seq_id_counter): create_sequence_data(),
|
||||
},
|
||||
sampling_params=create_sampling_params(
|
||||
0, stop_token_ids=stop_token_ids),
|
||||
block_tables={},
|
||||
)
|
||||
]
|
||||
}
|
||||
|
||||
if seed == 0:
|
||||
test_cases = [
|
||||
prompt_without_penalization,
|
||||
prompt_with_penalization,
|
||||
stop_penalizing_after_min_tokens,
|
||||
simple_combination,
|
||||
]
|
||||
else:
|
||||
test_cases = [generate_test_case()]
|
||||
|
||||
def run_test_case(*,
|
||||
expected_penalization=None,
|
||||
seq_group_metadata_list=None):
|
||||
assert expected_penalization, "Invalid test case"
|
||||
assert seq_group_metadata_list, "Invalid test case"
|
||||
|
||||
batch_size = 0
|
||||
prompt_lens = []
|
||||
sampling_params_per_seq = []
|
||||
for sgm in seq_group_metadata_list:
|
||||
num_seqs = len(sgm.seq_data)
|
||||
batch_size += num_seqs
|
||||
sampling_params = sgm.sampling_params
|
||||
for seq_id in sgm.seq_data:
|
||||
prompt_lens.append(sgm.seq_data[seq_id].get_prompt_len())
|
||||
sampling_params_per_seq.append(sampling_params)
|
||||
|
||||
_, fake_logits, sampler, model_runner = _prepare_test(batch_size)
|
||||
sampling_metadata = model_runner._prepare_sample(
|
||||
seq_group_metadata_list,
|
||||
prompt_lens=prompt_lens,
|
||||
subquery_lens=prompt_lens)
|
||||
# the logits tensor is modified in-place by the sampler
|
||||
_ = sampler(logits=fake_logits, sampling_metadata=sampling_metadata)
|
||||
|
||||
for logits_idx, (should_penalize, sampling_params) in enumerate(
|
||||
zip(expected_penalization, sampling_params_per_seq)):
|
||||
|
||||
tokens_to_check = [sampling_params.eos_token_id]
|
||||
if sampling_params.stop_token_ids:
|
||||
tokens_to_check.extend(sampling_params.stop_token_ids)
|
||||
tokens_to_check = set(tokens_to_check)
|
||||
|
||||
if should_penalize:
|
||||
for token_id in tokens_to_check:
|
||||
assert fake_logits[logits_idx, token_id] == -float(
|
||||
'inf'
|
||||
), f"Expected token {token_id} for logits row {logits_idx}"
|
||||
" to be penalized"
|
||||
# no other tokens should be set to -inf
|
||||
assert torch.count_nonzero(
|
||||
fake_logits[logits_idx, :] == -float('inf')) == len(
|
||||
tokens_to_check
|
||||
), f"Expected only {len(tokens_to_check)} to be penalized"
|
||||
else:
|
||||
# no tokens should be set to -inf
|
||||
assert torch.count_nonzero(
|
||||
fake_logits[logits_idx, :] ==
|
||||
-float('inf')) == 0, "No tokens should have been penalized"
|
||||
|
||||
del model_runner
|
||||
|
||||
for test_case in test_cases:
|
||||
run_test_case(**test_case)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("seed", RANDOM_SEEDS)
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
def test_sampler_mixed(seed: int, device: str):
|
||||
|
@ -282,6 +282,9 @@ class LLMEngine:
|
||||
# Defensive copy of SamplingParams, which are used by the sampler,
|
||||
# this doesn't deep-copy LogitsProcessor objects
|
||||
sampling_params = sampling_params.clone()
|
||||
# inject the eos token id into the sampling_params to support min_tokens
|
||||
# processing
|
||||
sampling_params.eos_token_id = seq.eos_token_id
|
||||
|
||||
# Create the sequence group.
|
||||
seq_group = SequenceGroup(request_id, [seq], sampling_params,
|
||||
@ -713,6 +716,21 @@ class LLMEngine:
|
||||
def _check_stop(self, seq: Sequence,
|
||||
sampling_params: SamplingParams) -> None:
|
||||
"""Stop the finished sequences."""
|
||||
# Check if the sequence has reached max_model_len.
|
||||
if seq.get_len() > self.scheduler_config.max_model_len:
|
||||
seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
|
||||
return
|
||||
|
||||
# Check if the sequence has reached max_tokens.
|
||||
if seq.get_output_len() == sampling_params.max_tokens:
|
||||
seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
|
||||
return
|
||||
|
||||
# Check if the minimum number of tokens has been generated yet;
|
||||
# skip the stop string/token checks if not
|
||||
if seq.get_output_len() < sampling_params.min_tokens:
|
||||
return
|
||||
|
||||
for stop_str in sampling_params.stop:
|
||||
if seq.output_text.endswith(stop_str):
|
||||
self._finalize_sequence(seq, sampling_params, stop_str)
|
||||
@ -725,16 +743,6 @@ class LLMEngine:
|
||||
seq.status = SequenceStatus.FINISHED_STOPPED
|
||||
return
|
||||
|
||||
# Check if the sequence has reached max_model_len.
|
||||
if seq.get_len() > self.scheduler_config.max_model_len:
|
||||
seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
|
||||
return
|
||||
|
||||
# Check if the sequence has reached max_tokens.
|
||||
if seq.get_output_len() == sampling_params.max_tokens:
|
||||
seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
|
||||
return
|
||||
|
||||
# Check if the sequence has generated the EOS token.
|
||||
if ((not sampling_params.ignore_eos)
|
||||
and seq.get_last_token_id() == seq.eos_token_id):
|
||||
|
@ -88,6 +88,7 @@ class ChatCompletionRequest(BaseModel):
|
||||
length_penalty: Optional[float] = 1.0
|
||||
early_stopping: Optional[bool] = False
|
||||
ignore_eos: Optional[bool] = False
|
||||
min_tokens: Optional[int] = 0
|
||||
stop_token_ids: Optional[List[int]] = Field(default_factory=list)
|
||||
skip_special_tokens: Optional[bool] = True
|
||||
spaces_between_special_tokens: Optional[bool] = True
|
||||
@ -165,6 +166,7 @@ class ChatCompletionRequest(BaseModel):
|
||||
stop=self.stop,
|
||||
stop_token_ids=self.stop_token_ids,
|
||||
max_tokens=self.max_tokens,
|
||||
min_tokens=self.min_tokens,
|
||||
logprobs=self.top_logprobs if self.logprobs else None,
|
||||
prompt_logprobs=self.top_logprobs if self.echo else None,
|
||||
best_of=self.best_of,
|
||||
@ -224,6 +226,7 @@ class CompletionRequest(BaseModel):
|
||||
early_stopping: Optional[bool] = False
|
||||
stop_token_ids: Optional[List[int]] = Field(default_factory=list)
|
||||
ignore_eos: Optional[bool] = False
|
||||
min_tokens: Optional[int] = 0
|
||||
skip_special_tokens: Optional[bool] = True
|
||||
spaces_between_special_tokens: Optional[bool] = True
|
||||
# doc: end-completion-sampling-params
|
||||
@ -296,6 +299,7 @@ class CompletionRequest(BaseModel):
|
||||
stop_token_ids=self.stop_token_ids,
|
||||
ignore_eos=self.ignore_eos,
|
||||
max_tokens=self.max_tokens if not echo_without_generation else 1,
|
||||
min_tokens=self.min_tokens,
|
||||
logprobs=self.logprobs,
|
||||
use_beam_search=self.use_beam_search,
|
||||
early_stopping=self.early_stopping,
|
||||
|
@ -1,4 +1,5 @@
|
||||
"""A layer that samples the next tokens from the model's outputs."""
|
||||
import itertools
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
@ -36,6 +37,10 @@ class Sampler(nn.Module):
|
||||
assert logits is not None
|
||||
_, vocab_size = logits.shape
|
||||
|
||||
# Apply min_tokens penalty which sets stop tokens to -inf if min_tokens
|
||||
# have not been generated yet
|
||||
logits = _apply_min_tokens_penalty(logits, sampling_metadata)
|
||||
|
||||
# Prepare sampling tensors with pinned memory to avoid blocking.
|
||||
(sampling_tensors, do_penalties, do_top_p_top_k,
|
||||
do_min_p) = SamplingTensors.from_sampling_metadata(
|
||||
@ -94,6 +99,42 @@ def _get_bin_counts_and_mask(
|
||||
return bin_counts, mask
|
||||
|
||||
|
||||
def _apply_min_tokens_penalty(
|
||||
logits: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
) -> torch.Tensor:
|
||||
# list of indices in logits that will be set to -inf
|
||||
logits_to_penalize = []
|
||||
start_idx = 0
|
||||
for seq_ids, sampling_params in sampling_metadata.seq_groups:
|
||||
min_tokens = sampling_params.min_tokens
|
||||
if min_tokens > 0:
|
||||
seqs_to_penalize = []
|
||||
for i, seq_id in enumerate(seq_ids):
|
||||
seq_data = sampling_metadata.seq_data[seq_id]
|
||||
if len(seq_data.output_token_ids) < min_tokens:
|
||||
seqs_to_penalize.append(i)
|
||||
|
||||
if seqs_to_penalize:
|
||||
# convert to the index into logits
|
||||
seqs_to_penalize = [start_idx + i for i in seqs_to_penalize]
|
||||
# use set() to remove any duplicates
|
||||
token_ids_to_penalize = set(sampling_params.stop_token_ids +
|
||||
[sampling_params.eos_token_id])
|
||||
# itertools.product pairs each seq index with every token id
|
||||
logits_to_penalize.extend(
|
||||
itertools.product(seqs_to_penalize, token_ids_to_penalize))
|
||||
|
||||
start_idx += len(seq_ids)
|
||||
|
||||
if logits_to_penalize:
|
||||
# use zip and * to group indices along each dimension
|
||||
# eg. [ (1,2), (1,3), (5,6) ] -> ( (1,1,5), (2,3,6) )
|
||||
logits[tuple(zip(*logits_to_penalize))] = -float("inf")
|
||||
|
||||
return logits
|
||||
|
||||
|
||||
def _apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor,
|
||||
output_tokens_tensor: torch.Tensor,
|
||||
presence_penalties: torch.Tensor,
|
||||
|
@ -79,6 +79,8 @@ class SamplingParams:
|
||||
ignore_eos: Whether to ignore the EOS token and continue generating
|
||||
tokens after the EOS token is generated.
|
||||
max_tokens: Maximum number of tokens to generate per output sequence.
|
||||
min_tokens: Minimum number of tokens to generate per output sequence
|
||||
before EOS or stop_token_ids can be generated
|
||||
logprobs: Number of log probabilities to return per output token.
|
||||
Note that the implementation follows the OpenAI API: The return
|
||||
result includes the log probabilities on the `logprobs` most likely
|
||||
@ -113,6 +115,7 @@ class SamplingParams:
|
||||
include_stop_str_in_output: bool = False,
|
||||
ignore_eos: bool = False,
|
||||
max_tokens: Optional[int] = 16,
|
||||
min_tokens: int = 0,
|
||||
logprobs: Optional[int] = None,
|
||||
prompt_logprobs: Optional[int] = None,
|
||||
skip_special_tokens: bool = True,
|
||||
@ -144,6 +147,7 @@ class SamplingParams:
|
||||
self.stop_token_ids = list(stop_token_ids)
|
||||
self.ignore_eos = ignore_eos
|
||||
self.max_tokens = max_tokens
|
||||
self.min_tokens = min_tokens
|
||||
self.logprobs = logprobs
|
||||
self.prompt_logprobs = prompt_logprobs
|
||||
self.skip_special_tokens = skip_special_tokens
|
||||
@ -161,6 +165,8 @@ class SamplingParams:
|
||||
self.top_k = -1
|
||||
self.min_p = 0.0
|
||||
self._verify_greedy_sampling()
|
||||
# injected by the engine
|
||||
self.eos_token_id = None
|
||||
|
||||
def _verify_args(self) -> None:
|
||||
if self.n < 1:
|
||||
@ -191,6 +197,13 @@ class SamplingParams:
|
||||
if self.max_tokens is not None and self.max_tokens < 1:
|
||||
raise ValueError(
|
||||
f"max_tokens must be at least 1, got {self.max_tokens}.")
|
||||
if self.min_tokens < 0:
|
||||
raise ValueError(f"min_tokens must be greater than or equal to 0, "
|
||||
f"got {self.min_tokens}.")
|
||||
if self.max_tokens is not None and self.min_tokens > self.max_tokens:
|
||||
raise ValueError(
|
||||
f"min_tokens must be less than or equal to "
|
||||
f"max_tokens={self.max_tokens}, got {self.min_tokens}.")
|
||||
if self.logprobs is not None and self.logprobs < 0:
|
||||
raise ValueError(
|
||||
f"logprobs must be non-negative, got {self.logprobs}.")
|
||||
@ -272,6 +285,7 @@ class SamplingParams:
|
||||
f"include_stop_str_in_output={self.include_stop_str_in_output}, "
|
||||
f"ignore_eos={self.ignore_eos}, "
|
||||
f"max_tokens={self.max_tokens}, "
|
||||
f"min_tokens={self.min_tokens}, "
|
||||
f"logprobs={self.logprobs}, "
|
||||
f"prompt_logprobs={self.prompt_logprobs}, "
|
||||
f"skip_special_tokens={self.skip_special_tokens}, "
|
||||
|
Reference in New Issue
Block a user