mirror of https://github.com/vllm-project/vllm.git
[v1] Support allowed_token_ids in v1 Sampler (#13210)
Signed-off-by: Lu Fang <lufang@fb.com>
@@ -43,6 +43,7 @@ def create_sampling_metadata(spec_tokens: List[List[int]]) -> SamplingMetadata:
         output_token_ids=[],
         min_tokens={},
         logit_bias=[None] * batch_size,
+        allowed_token_ids_mask=None,
     )
@@ -57,6 +57,26 @@ def _create_logit_bias(
     return res


+def _create_allowed_token_ids(
+    batch_size: int,
+    vocab_size: int,
+    num_allowed_token_ids: int,
+    device: torch.device,
+) -> Optional[torch.Tensor]:
+    mask: Optional[torch.Tensor] = None
+    for i in range(batch_size):
+        if i % 2 == 1:
+            continue
+        if mask is None:
+            mask = torch.zeros((batch_size, vocab_size),
+                               dtype=torch.bool,
+                               device=device)
+        start = min(i, vocab_size - 1)
+        end = min(i + num_allowed_token_ids, vocab_size - 1)
+        mask[i, start:end] = True
+    return mask
+
+
 def _create_default_sampling_metadata(
     num_output_tokens: int,
     batch_size: int,
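For intuition, here is a small standalone sketch (with made-up sizes, not the test's constants) of the mask pattern this helper produces: only even-numbered requests get a contiguous band of True entries, while odd-numbered requests are left without any mask.

# Illustration only; batch_size=4, vocab_size=6, num_allowed=2 are made-up values.
import torch

batch_size, vocab_size, num_allowed = 4, 6, 2
mask = torch.zeros((batch_size, vocab_size), dtype=torch.bool)
for i in range(0, batch_size, 2):  # odd rows are intentionally skipped
    mask[i, min(i, vocab_size - 1):min(i + num_allowed, vocab_size - 1)] = True
print(mask)
# tensor([[ True,  True, False, False, False, False],
#         [False, False, False, False, False, False],
#         [False, False,  True,  True, False, False],
#         [False, False, False, False, False, False]])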
@@ -92,6 +112,7 @@ def _create_default_sampling_metadata(
         no_penalties=True,
         min_tokens={},
         logit_bias=[None] * batch_size,
+        allowed_token_ids_mask=None,
     )
     return fake_sampling_metadata
@@ -253,7 +274,10 @@ def test_sampler_frequency_penalty(device: str, batch_size: int,
     sampling_metadata.frequency_penalties = _create_penalty_tensor(
         batch_size, frequency_penalty, torch.device(device))
     output_token_ids, sorted_token_ids_in_output = \
-        _create_weighted_output_token_list(batch_size, VOCAB_SIZE)
+        _create_weighted_output_token_list(
+            batch_size,
+            VOCAB_SIZE,
+        )
     sampling_metadata.output_token_ids = output_token_ids
     sampling_metadata.no_penalties = False
     sampler = Sampler()
@@ -262,8 +286,8 @@ def test_sampler_frequency_penalty(device: str, batch_size: int,
     for batch_idx in range(batch_size):
         non_penalized_token_id = logits[batch_idx].argmax().item()
         penalized_token_id = logits[batch_idx].argmin().item()
-        distinct_sorted_token_ids_in_output = \
-            sorted_token_ids_in_output[batch_idx]
+        distinct_sorted_token_ids_in_output = sorted_token_ids_in_output[
+            batch_idx]
         most_frequent_token_id = distinct_sorted_token_ids_in_output[
             len(distinct_sorted_token_ids_in_output) - 1]
         if frequency_penalty > 0:
@@ -272,8 +296,8 @@ def test_sampler_frequency_penalty(device: str, batch_size: int,
             # non-penalized token ID is not present in the output, while the
             # most penalized token is the one that occurs most frequently in
             # the output.
-            assert non_penalized_token_id \
-                not in distinct_sorted_token_ids_in_output
+            assert (non_penalized_token_id
+                    not in distinct_sorted_token_ids_in_output)
             assert penalized_token_id == most_frequent_token_id
         elif frequency_penalty < 0:
             # If `frequency_penalty` is set to < 0, it indicates
@@ -282,8 +306,7 @@ def test_sampler_frequency_penalty(device: str, batch_size: int,
             # in the output, while the penalized token ID is one that has not
             # yet appeared.
             assert non_penalized_token_id == most_frequent_token_id
-            assert penalized_token_id \
-                not in distinct_sorted_token_ids_in_output
+            assert penalized_token_id not in distinct_sorted_token_ids_in_output


 @pytest.mark.parametrize("device", CUDA_DEVICES)
@@ -318,18 +341,18 @@ def test_sampler_repetition_penalty(device: str, batch_size: int,
             # If `repetition_penalty` > 1.0, verify that the non-penalized
             # token ID has not been seen before, while the penalized token ID
             # exists either in the prompt or the output.
-            assert (non_penalized_token_id not in prompt_tokens and \
-                    non_penalized_token_id not in output_tokens)
-            assert (penalized_token_id in prompt_tokens or \
-                    penalized_token_id in output_tokens)
+            assert (non_penalized_token_id not in prompt_tokens
+                    and non_penalized_token_id not in output_tokens)
+            assert (penalized_token_id in prompt_tokens
+                    or penalized_token_id in output_tokens)
         elif repetition_penalty < 1.0:
             # If `repetition_penalty` < 1.0, verify that the penalized
             # token ID has not been seen before, while the non-penalized
             # token ID exists either in the prompt or the output.
-            assert (penalized_token_id not in prompt_tokens and \
-                    penalized_token_id not in output_tokens)
-            assert (non_penalized_token_id in prompt_tokens or \
-                    non_penalized_token_id in output_tokens)
+            assert (penalized_token_id not in prompt_tokens
+                    and penalized_token_id not in output_tokens)
+            assert (non_penalized_token_id in prompt_tokens
+                    or non_penalized_token_id in output_tokens)


 @pytest.mark.parametrize("device", CUDA_DEVICES)
@@ -404,3 +427,44 @@ def test_sampler_logit_bias(device: str, batch_size: int, bias_value: float):
                 1e-2)
         else:
             assert logits_for_req[token_id] == pytest.approx(1e-2)
+
+
+@pytest.mark.parametrize("device", CUDA_DEVICES)
+@pytest.mark.parametrize("batch_size", [1, 2, 32])
+@pytest.mark.parametrize("num_allowed_token_ids", [0, 1, 2])
+def test_sampler_allowed_token_ids(device: str, batch_size: int,
+                                   num_allowed_token_ids: int):
+    """
+    Test to verify that the allowed_token_ids mask is applied to the
+    logits: masked positions are set to -inf, while requests without a
+    mask entry are left untouched.
+    """
+    torch.set_default_device(device)
+    # Create fake logits where each token is assigned the same
+    # logit value.
+    fake_logits = _create_fake_logits(batch_size, VOCAB_SIZE)
+    sampling_metadata = _create_default_sampling_metadata(
+        NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, torch.device(device))
+    mask = _create_allowed_token_ids(
+        batch_size=batch_size,
+        vocab_size=VOCAB_SIZE,
+        num_allowed_token_ids=num_allowed_token_ids,
+        device=device,
+    )
+    sampling_metadata.allowed_token_ids_mask = mask
+    sampler = Sampler()
+    logits = sampler.apply_allowed_token_ids(fake_logits, sampling_metadata)
+    logits = logits.cpu()
+    for batch_idx in range(batch_size):
+        logits_for_req = logits[batch_idx]
+        if batch_idx % 2 == 1:
+            assert torch.all(logits_for_req != -float("inf"))
+            continue
+        for token_id in range(VOCAB_SIZE):
+            start = min(batch_idx, VOCAB_SIZE - 1)
+            end = min(batch_idx + num_allowed_token_ids, VOCAB_SIZE - 1)
+            if token_id >= start and token_id < end:
+                assert logits_for_req[token_id] == -float(
+                    "inf"), f"{batch_idx}, {token_id}"
+            else:
+                assert logits_for_req[token_id] != -float("inf")
@@ -66,6 +66,10 @@ def _construct_expected_sampling_metadata(
     temperature = [0.0 for _ in range(num_reqs)]
     min_tokens = {}
     logit_bias = [None] * num_reqs
+    allowed_token_ids_mask = torch.zeros(num_reqs,
+                                         VOCAB_SIZE,
+                                         dtype=torch.bool,
+                                         device=device)
     for req in reqs:
         if req.req_id not in req_ids_retained:
             continue
@@ -86,6 +90,10 @@ def _construct_expected_sampling_metadata(
             req.sampling_params.min_tokens,
             req.sampling_params.all_stop_token_ids)
         logit_bias[index_in_input_batch] = req.sampling_params.logit_bias
+        if req.sampling_params.allowed_token_ids:
+            allowed_token_ids_mask[index_in_input_batch][
+                req.sampling_params.allowed_token_ids] = True

     return SamplingMetadata(
         temperature=torch.tensor(temperature, dtype=torch.float,
                                  device=device),
@@ -121,6 +129,7 @@ def _construct_expected_sampling_metadata(
             and all(x == 0 for x in frequency_penalties)
             and all(x == 1 for x in repetition_penalties)),
         logit_bias=logit_bias,
+        allowed_token_ids_mask=allowed_token_ids_mask,
     )
@@ -242,3 +251,7 @@ def test_sampling_metadata_in_input_batch(device: str, batch_size: int):
     assert expected_sampling_metadata.no_penalties == \
         sampling_metadata.no_penalties
     assert expected_sampling_metadata.logit_bias == sampling_metadata.logit_bias
+    if sampling_metadata.allowed_token_ids_mask:
+        assert torch.allclose(
+            expected_sampling_metadata.allowed_token_ids_mask,
+            sampling_metadata.allowed_token_ids_mask)
@@ -83,6 +83,19 @@ class Processor:
             raise ValueError(f"Got lora_request {lora_request} but LoRA is "
                              "not enabled!")

+    def _validate_allowed_token_ids(
+        self,
+        params: Union[SamplingParams, PoolingParams],
+    ) -> None:
+        if not isinstance(params, SamplingParams):
+            return
+        if params.allowed_token_ids is None:
+            return
+        if not all(0 <= tid < self.model_config.vocab_size
+                   for tid in params.allowed_token_ids):
+            raise ValueError(
+                "allowed_token_ids contains out-of-vocab token id")
+
     def process_inputs(
         self,
         request_id: str,
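A self-contained sketch of the same range check outside the Processor class (the vocab_size value and token ids below are hypothetical), showing what the new validation accepts and rejects:

# Hypothetical standalone version of the out-of-vocab check; not the vllm class.
from typing import List


def validate_allowed_token_ids(allowed_token_ids: List[int],
                               vocab_size: int = 32000) -> None:
    if not all(0 <= tid < vocab_size for tid in allowed_token_ids):
        raise ValueError("allowed_token_ids contains out-of-vocab token id")


validate_allowed_token_ids([0, 7, 31999])      # in range: no error
try:
    validate_allowed_token_ids([0, 7, 32000])  # 32000 >= vocab_size
except ValueError as exc:
    print(exc)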
@@ -100,6 +113,7 @@ class Processor:

         self._validate_logprobs(params)
         self._validate_lora(lora_request)
+        self._validate_allowed_token_ids(params)

         if arrival_time is None:
             arrival_time = time.time()
@@ -37,3 +37,7 @@ class SamplingMetadata:
     min_tokens: Dict[int, Tuple[int, Set[int]]]

     logit_bias: List[Optional[Dict[int, float]]]
+
+    # `allowed_token_ids_mask` is a 2D bool tensor of shape (max batch size,
+    # vocab size).
+    allowed_token_ids_mask: Optional[torch.Tensor]
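As a rough sketch of the documented layout (sizes and ids below are made up): each request owns one row of the (max batch size, vocab size) boolean tensor, and the row is addressed by the request's slot in the batch.

# Example values only; the real tensor lives on the model's device.
import torch

max_num_reqs, vocab_size = 8, 16
allowed_token_ids_mask = torch.zeros(max_num_reqs, vocab_size, dtype=torch.bool)

req_index, allowed_ids = 3, [2, 5, 11]
allowed_token_ids_mask[req_index, allowed_ids] = True
print(allowed_token_ids_mask[req_index].nonzero(as_tuple=True)[0])
# tensor([ 2,  5, 11])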
@@ -47,6 +47,8 @@ class Sampler(nn.Module):

         # Use float32 for the logits.
         logits = logits.to(torch.float32)
+        # Apply allowed token ids.
+        logits = self.apply_allowed_token_ids(logits, sampling_metadata)
         # Apply logits bias.
         logits = self.apply_logits_bias(logits, sampling_metadata)
         # Apply penalties (e.g., min_tokens, freq_penalties).
@@ -184,11 +186,13 @@ class Sampler(nn.Module):
         if not sampling_metadata.no_penalties:
             assert sampling_metadata.prompt_token_ids is not None
             logits = apply_all_penalties(
-                logits, sampling_metadata.prompt_token_ids,
+                logits,
+                sampling_metadata.prompt_token_ids,
                 sampling_metadata.presence_penalties,
                 sampling_metadata.frequency_penalties,
                 sampling_metadata.repetition_penalties,
-                sampling_metadata.output_token_ids)
+                sampling_metadata.output_token_ids,
+            )
         return logits

     def apply_min_p(
@@ -226,3 +230,13 @@ class Sampler(nn.Module):
                 for token_id, bias in logit_bias.items():
                     logits[i, token_id] += bias
         return logits
+
+    def apply_allowed_token_ids(
+        self,
+        logits: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> torch.Tensor:
+        if sampling_metadata.allowed_token_ids_mask is not None:
+            logits.masked_fill_(sampling_metadata.allowed_token_ids_mask,
+                                float("-inf"))
+        return logits
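For intuition, a minimal standalone example (made-up logits and mask, not the Sampler class) of what masked_fill_ does here: positions where the mask is True are filled with -inf and therefore receive zero probability after softmax.

# Standalone illustration of masked_fill_ with a boolean mask; values are made up.
import torch

logits = torch.tensor([[1.0, 2.0, 3.0, 4.0]])
mask = torch.tensor([[False, True, False, True]])

logits.masked_fill_(mask, float("-inf"))
print(logits)                          # tensor([[1., -inf, 3., -inf]])
print(torch.softmax(logits, dim=-1))   # masked positions get probability 0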
@@ -143,7 +143,7 @@ class InputBatch:
             device="cpu",
             pin_memory=pin_memory)
         self.frequency_penalties_cpu = \
-            self.frequency_penalties_cpu_tensor.numpy()
+            self.frequency_penalties_cpu_tensor.numpy()
         self.frequency_penalties_reqs: Set[str] = set()

         # Presence penalty related data structures
@@ -168,7 +168,7 @@ class InputBatch:
             device="cpu",
             pin_memory=pin_memory)
         self.repetition_penalties_cpu = \
-            self.repetition_penalties_cpu_tensor.numpy()
+            self.repetition_penalties_cpu_tensor.numpy()
         self.repetition_penalties_reqs: Set[str] = set()

         # req_index -> (min_tokens, stop_token_ids)
@@ -192,6 +192,9 @@ class InputBatch:

         self.logit_bias: List[Optional[Dict[int,
                                             float]]] = [None] * max_num_reqs
+        self.has_allowed_token_ids: Set[str] = set()
+        self.allowed_token_ids_mask: Optional[torch.Tensor] = None
+        self.allowed_token_ids_mask_cpu_tensor: Optional[torch.Tensor] = None

         self.req_output_token_ids: List[Optional[List[int]]] = []
@@ -287,6 +290,22 @@ class InputBatch:
         if sampling_params.logit_bias is not None:
             self.logit_bias[req_index] = sampling_params.logit_bias

+        if sampling_params.allowed_token_ids:
+            self.has_allowed_token_ids.add(req_id)
+            if self.allowed_token_ids_mask_cpu_tensor is None:
+                # Lazy allocation for this tensor, which can be large.
+                self.allowed_token_ids_mask = torch.zeros(self.max_num_reqs,
+                                                          self.vocab_size,
+                                                          dtype=torch.bool,
+                                                          device=self.device)
+                self.allowed_token_ids_mask_cpu_tensor = torch.zeros(
+                    self.max_num_reqs,
+                    self.vocab_size,
+                    dtype=torch.bool,
+                    device="cpu")
+            self.allowed_token_ids_mask_cpu_tensor[req_index][
+                sampling_params.allowed_token_ids] = True
+
         # Add request lora ID
         if request.lora_request:
             lora_id = request.lora_request.lora_int_id
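The lazy-allocation pattern above, reduced to a standalone sketch (class name, sizes, and requests are hypothetical): the potentially large (max_num_reqs, vocab_size) tensors are only created when the first request that actually uses allowed_token_ids arrives, and later requests just write their own row.

# Standalone sketch of lazy allocation for a potentially large per-request mask.
from typing import List, Optional

import torch


class MaskStore:  # hypothetical helper, not a vllm class

    def __init__(self, max_num_reqs: int, vocab_size: int):
        self.max_num_reqs = max_num_reqs
        self.vocab_size = vocab_size
        self.mask_cpu: Optional[torch.Tensor] = None  # allocated on first use

    def add_request(self, req_index: int,
                    allowed_token_ids: List[int]) -> None:
        if not allowed_token_ids:
            return
        if self.mask_cpu is None:
            # Allocate only once the feature is actually used.
            self.mask_cpu = torch.zeros(self.max_num_reqs,
                                        self.vocab_size,
                                        dtype=torch.bool)
        self.mask_cpu[req_index, allowed_token_ids] = True


store = MaskStore(max_num_reqs=4, vocab_size=10)
store.add_request(0, [])        # nothing allocated yet
print(store.mask_cpu)           # None
store.add_request(1, [1, 3])    # first real user triggers allocation
print(store.mask_cpu[1].nonzero(as_tuple=True)[0])  # tensor([1, 3])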
@@ -332,6 +351,9 @@ class InputBatch:
             self.request_lora_mapping[req_index] = 0

         self.logit_bias[req_index] = None
+        self.has_allowed_token_ids.discard(req_id)
+        if self.allowed_token_ids_mask_cpu_tensor is not None:
+            self.allowed_token_ids_mask_cpu_tensor[req_index].fill_(False)
         return req_index

     def condense(self, empty_req_indices: List[int]) -> None:
@@ -400,6 +422,11 @@ class InputBatch:

         self.logit_bias[empty_index] = self.logit_bias[last_req_index]

+        if self.allowed_token_ids_mask_cpu_tensor is not None:
+            self.allowed_token_ids_mask_cpu_tensor[
+                empty_index] = self.allowed_token_ids_mask_cpu_tensor[
+                    last_req_index]
+
         # Decrement last_req_index since it is now empty.
         last_req_index -= 1
@@ -442,6 +469,13 @@ class InputBatch:
         else:
             prompt_token_ids = None

+        allowed_token_ids_mask: Optional[torch.Tensor] = None
+        if not self.no_allowed_token_ids:
+            assert self.allowed_token_ids_mask is not None
+            copy_slice(self.allowed_token_ids_mask_cpu_tensor,
+                       self.allowed_token_ids_mask, num_reqs)
+            allowed_token_ids_mask = self.allowed_token_ids_mask[:num_reqs]
+
         return SamplingMetadata(
             temperature=temperature,
             all_greedy=self.all_greedy,
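copy_slice is a vllm helper; the sketch below only approximates the behavior assumed here (copy the first num_reqs rows from the pinned CPU tensor to the device tensor, then pass a view of those rows to SamplingMetadata), with made-up sizes:

# Rough approximation of the CPU -> device slice copy; sizes are hypothetical.
import torch

max_num_reqs, vocab_size, num_reqs = 8, 16, 3
mask_cpu = torch.zeros(max_num_reqs, vocab_size, dtype=torch.bool)
mask_cpu[1, [2, 4]] = True

device = "cuda" if torch.cuda.is_available() else "cpu"
mask_dev = torch.zeros(max_num_reqs, vocab_size, dtype=torch.bool, device=device)

# Assumed equivalent of copy_slice(mask_cpu, mask_dev, num_reqs):
mask_dev[:num_reqs].copy_(mask_cpu[:num_reqs], non_blocking=True)
allowed_token_ids_mask = mask_dev[:num_reqs]  # view handed to SamplingMetadata
print(allowed_token_ids_mask.shape)           # torch.Size([3, 16])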
@@ -460,6 +494,7 @@ class InputBatch:
             min_tokens=self.min_tokens,
             no_penalties=self.no_penalties,
             logit_bias=self.logit_bias[:num_reqs],
+            allowed_token_ids_mask=allowed_token_ids_mask,
         )

     def get_sampling_metadata(
@@ -550,3 +585,7 @@ class InputBatch:
     @property
     def no_prompt_logprob(self) -> bool:
         return not self.num_prompt_logprobs
+
+    @property
+    def no_allowed_token_ids(self) -> bool:
+        return len(self.has_allowed_token_ids) == 0
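Taken together, these changes let a v1-engine request restrict decoding to an explicit token-id set. A minimal usage sketch, assuming a small placeholder model and example token ids (which must be valid ids below the model's vocab_size, per the new Processor check):

# Usage sketch only; the model name and token ids are placeholder values.
from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")
params = SamplingParams(
    max_tokens=8,
    allowed_token_ids=[262, 318, 257],  # example ids, must be < vocab_size
)
outputs = llm.generate(["The capital of France is"], params)
print(outputs[0].outputs[0].text)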