[v1] Support allowed_token_ids in v1 Sampler (#13210)

Signed-off-by: Lu Fang <lufang@fb.com>
Lu Fang
2025-02-21 22:13:05 -08:00
committed by GitHub
parent 8aca27fa11
commit bb78fb318e
7 changed files with 168 additions and 19 deletions

View File

@ -43,6 +43,7 @@ def create_sampling_metadata(spec_tokens: List[List[int]]) -> SamplingMetadata:
output_token_ids=[],
min_tokens={},
logit_bias=[None] * batch_size,
allowed_token_ids_mask=None,
)

View File

@ -57,6 +57,26 @@ def _create_logit_bias(
return res
def _create_allowed_token_ids(
batch_size: int,
vocab_size: int,
num_allowed_token_ids: int,
device: torch.device,
) -> Optional[torch.Tensor]:
mask: Optional[torch.Tensor] = None
for i in range(batch_size):
if i % 2 == 1:
continue
if mask is None:
mask = torch.zeros((batch_size, vocab_size),
dtype=torch.bool,
device=device)
start = min(i, vocab_size - 1)
end = min(i + num_allowed_token_ids, vocab_size - 1)
mask[i, start:end] = True
return mask
def _create_default_sampling_metadata(
num_output_tokens: int,
batch_size: int,
@ -92,6 +112,7 @@ def _create_default_sampling_metadata(
no_penalties=True,
min_tokens={},
logit_bias=[None] * batch_size,
allowed_token_ids_mask=None,
)
return fake_sampling_metadata
@ -253,7 +274,10 @@ def test_sampler_frequency_penalty(device: str, batch_size: int,
sampling_metadata.frequency_penalties = _create_penalty_tensor(
batch_size, frequency_penalty, torch.device(device))
output_token_ids, sorted_token_ids_in_output = \
    _create_weighted_output_token_list(
        batch_size,
        VOCAB_SIZE,
    )
sampling_metadata.output_token_ids = output_token_ids
sampling_metadata.no_penalties = False
sampler = Sampler()
@ -262,8 +286,8 @@ def test_sampler_frequency_penalty(device: str, batch_size: int,
for batch_idx in range(batch_size):
non_penalized_token_id = logits[batch_idx].argmax().item()
penalized_token_id = logits[batch_idx].argmin().item()
distinct_sorted_token_ids_in_output = sorted_token_ids_in_output[
    batch_idx]
most_frequent_token_id = distinct_sorted_token_ids_in_output[
len(distinct_sorted_token_ids_in_output) - 1]
if frequency_penalty > 0:
@ -272,8 +296,8 @@ def test_sampler_frequency_penalty(device: str, batch_size: int,
# non-penalized token ID is not present in the output, while the
# most penalized token is the one that occurs most frequently in
# the output.
assert (non_penalized_token_id
        not in distinct_sorted_token_ids_in_output)
assert penalized_token_id == most_frequent_token_id
elif frequency_penalty < 0:
# If `frequency_penalty` is set to < 0, it indicates
@ -282,8 +306,7 @@ def test_sampler_frequency_penalty(device: str, batch_size: int,
# in the output, while the penalized token ID is one that has not
# yet appeared.
assert non_penalized_token_id == most_frequent_token_id
assert penalized_token_id not in distinct_sorted_token_ids_in_output
@pytest.mark.parametrize("device", CUDA_DEVICES)
@ -318,18 +341,18 @@ def test_sampler_repetition_penalty(device: str, batch_size: int,
# If `repetition_penalty` > 1.0, verify that the non-penalized
# token ID has not been seen before, while the penalized token ID
# exists either in the prompt or the output.
assert (non_penalized_token_id not in prompt_tokens
        and non_penalized_token_id not in output_tokens)
assert (penalized_token_id in prompt_tokens
        or penalized_token_id in output_tokens)
elif repetition_penalty < 1.0:
# If `repetition_penalty` < 1.0, verify that the penalized
# token ID has not been seen before, while the non-penalized
# token ID exists either in the prompt or the output.
assert (penalized_token_id not in prompt_tokens
        and penalized_token_id not in output_tokens)
assert (non_penalized_token_id in prompt_tokens
        or non_penalized_token_id in output_tokens)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@ -404,3 +427,44 @@ def test_sampler_logit_bias(device: str, batch_size: int, bias_value: float):
1e-2)
else:
assert logits_for_req[token_id] == pytest.approx(1e-2)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("batch_size", [1, 2, 32])
@pytest.mark.parametrize("num_allowed_token_ids", [0, 1, 2])
def test_sampler_allowed_token_ids(device: str, batch_size: int,
num_allowed_token_ids: int):
"""
Test to verify that when the repetition penalty is enabled, tokens
are penalized based on their presence in the prompt or the existing
output.
"""
torch.set_default_device(device)
# Create fake logits where each token is assigned the same
# logit value.
fake_logits = _create_fake_logits(batch_size, VOCAB_SIZE)
sampling_metadata = _create_default_sampling_metadata(
NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, torch.device(device))
mask = _create_allowed_token_ids(
batch_size=batch_size,
vocab_size=VOCAB_SIZE,
num_allowed_token_ids=num_allowed_token_ids,
device=device,
)
sampling_metadata.allowed_token_ids_mask = mask
sampler = Sampler()
logits = sampler.apply_allowed_token_ids(fake_logits, sampling_metadata)
logits = logits.cpu()
for batch_idx in range(batch_size):
logits_for_req = logits[batch_idx]
if batch_idx % 2 == 1:
assert torch.all(logits_for_req != -float("inf"))
continue
for token_id in range(VOCAB_SIZE):
start = min(batch_idx, VOCAB_SIZE - 1)
end = min(batch_idx + num_allowed_token_ids, VOCAB_SIZE - 1)
if start <= token_id < end:
assert logits_for_req[token_id] == -float(
"inf"), f"{batch_idx}, {token_id}"
else:
assert logits_for_req[token_id] != -float("inf")
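
For reference, a minimal standalone sketch of the masking behavior this test exercises, using only torch; the shapes and values below are illustrative and not part of the test suite.

import torch

# Illustrative shapes; the test above uses (batch_size, VOCAB_SIZE).
batch_size, vocab_size = 2, 8
logits = torch.zeros(batch_size, vocab_size)

# Row 0 carries a mask, row 1 does not (mirroring the even/odd split above).
mask = torch.zeros(batch_size, vocab_size, dtype=torch.bool)
mask[0, 0:3] = True

# The sampler fills masked positions with -inf in place.
logits.masked_fill_(mask, float("-inf"))

assert torch.isinf(logits[0, 0:3]).all()
assert torch.isfinite(logits[1]).all()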

View File

@ -66,6 +66,10 @@ def _construct_expected_sampling_metadata(
temperature = [0.0 for _ in range(num_reqs)]
min_tokens = {}
logit_bias = [None] * num_reqs
allowed_token_ids_mask = torch.zeros(num_reqs,
VOCAB_SIZE,
dtype=torch.bool,
device=device)
for req in reqs:
if req.req_id not in req_ids_retained:
continue
@ -86,6 +90,10 @@ def _construct_expected_sampling_metadata(
req.sampling_params.min_tokens,
req.sampling_params.all_stop_token_ids)
logit_bias[index_in_input_batch] = req.sampling_params.logit_bias
if req.sampling_params.allowed_token_ids:
allowed_token_ids_mask[index_in_input_batch][
req.sampling_params.allowed_token_ids] = True
return SamplingMetadata(
temperature=torch.tensor(temperature, dtype=torch.float,
device=device),
@ -121,6 +129,7 @@ def _construct_expected_sampling_metadata(
and all(x == 0 for x in frequency_penalties)
and all(x == 1 for x in repetition_penalties)),
logit_bias=logit_bias,
allowed_token_ids_mask=allowed_token_ids_mask,
)
@ -242,3 +251,7 @@ def test_sampling_metadata_in_input_batch(device: str, batch_size: int):
assert expected_sampling_metadata.no_penalties == \
sampling_metadata.no_penalties
assert expected_sampling_metadata.logit_bias == sampling_metadata.logit_bias
if sampling_metadata.allowed_token_ids_mask is not None:
assert torch.allclose(
expected_sampling_metadata.allowed_token_ids_mask,
sampling_metadata.allowed_token_ids_mask)
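
As an aside, a small sketch of the per-request scatter used above to build the expected mask, with hypothetical request data; list indexing on a tensor row sets exactly the listed columns.

import torch

VOCAB_SIZE = 16
num_reqs = 3
# Hypothetical per-request allowed ids; None means no restriction.
per_request_allowed = [None, [2, 5, 7], None]

mask = torch.zeros(num_reqs, VOCAB_SIZE, dtype=torch.bool)
for row, allowed in enumerate(per_request_allowed):
    if allowed:
        # Advanced indexing sets only the listed columns in this row.
        mask[row][allowed] = True

assert mask[1, [2, 5, 7]].all() and not mask[0].any()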

View File

@ -83,6 +83,19 @@ class Processor:
raise ValueError(f"Got lora_request {lora_request} but LoRA is "
"not enabled!")
def _validate_allowed_token_ids(
self,
params: Union[SamplingParams, PoolingParams],
) -> None:
if not isinstance(params, SamplingParams):
return
if params.allowed_token_ids is None:
return
if not all(0 <= tid < self.model_config.vocab_size
for tid in params.allowed_token_ids):
raise ValueError(
"allowed_token_ids contains out-of-vocab token id")
def process_inputs(
self,
request_id: str,
@ -100,6 +113,7 @@ class Processor:
self._validate_logprobs(params)
self._validate_lora(lora_request)
self._validate_allowed_token_ids(params)
if arrival_time is None:
arrival_time = time.time()
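
A standalone sketch of the range check performed by _validate_allowed_token_ids, assuming an illustrative vocab size; the helper name below is hypothetical and not part of this change.

from typing import List, Optional


def validate_allowed_token_ids(allowed_token_ids: Optional[List[int]],
                               vocab_size: int) -> None:
    # Hypothetical helper mirroring the validation above: every requested
    # token id must fall inside [0, vocab_size).
    if allowed_token_ids is None:
        return
    if not all(0 <= tid < vocab_size for tid in allowed_token_ids):
        raise ValueError("allowed_token_ids contains out-of-vocab token id")


validate_allowed_token_ids([1, 5, 42], vocab_size=128)   # passes
try:
    validate_allowed_token_ids([1, 128], vocab_size=128)  # 128 is out of range
except ValueError as err:
    print(err)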

View File

@ -37,3 +37,7 @@ class SamplingMetadata:
min_tokens: Dict[int, Tuple[int, Set[int]]]
logit_bias: List[Optional[Dict[int, float]]]
# `allowed_token_ids_mask` is a 2D bool tensor of shape (max batch size,
# vocab size).
allowed_token_ids_mask: Optional[torch.Tensor]

View File

@ -47,6 +47,8 @@ class Sampler(nn.Module):
# Use float32 for the logits.
logits = logits.to(torch.float32)
# Apply allowed token ids.
logits = self.apply_allowed_token_ids(logits, sampling_metadata)
# Apply logits bias.
logits = self.apply_logits_bias(logits, sampling_metadata)
# Apply penalties (e.g., min_tokens, freq_penalties).
@ -184,11 +186,13 @@ class Sampler(nn.Module):
if not sampling_metadata.no_penalties:
assert sampling_metadata.prompt_token_ids is not None
logits = apply_all_penalties(
logits,
sampling_metadata.prompt_token_ids,
sampling_metadata.presence_penalties,
sampling_metadata.frequency_penalties,
sampling_metadata.repetition_penalties,
sampling_metadata.output_token_ids,
)
return logits
def apply_min_p(
@ -226,3 +230,13 @@ class Sampler(nn.Module):
for token_id, bias in logit_bias.items():
logits[i, token_id] += bias
return logits
def apply_allowed_token_ids(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> torch.Tensor:
if sampling_metadata.allowed_token_ids_mask is not None:
logits.masked_fill_(sampling_metadata.allowed_token_ids_mask,
float("-inf"))
return logits
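
To illustrate the downstream effect of apply_allowed_token_ids, a small sketch with made-up values: positions filled with -inf receive exactly zero probability after softmax.

import torch

logits = torch.tensor([[1.0, 2.0, 3.0, 4.0]])
mask = torch.tensor([[True, False, True, False]])

# Same in-place fill the sampler performs when a mask is present.
logits.masked_fill_(mask, float("-inf"))

probs = torch.softmax(logits, dim=-1)
# Masked positions contribute no probability mass at all.
assert probs[0, 0] == 0.0 and probs[0, 2] == 0.0
print(probs)  # tensor([[0.0000, 0.1192, 0.0000, 0.8808]])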

View File

@ -143,7 +143,7 @@ class InputBatch:
device="cpu",
pin_memory=pin_memory)
self.frequency_penalties_cpu = \
self.frequency_penalties_cpu_tensor.numpy()
self.frequency_penalties_reqs: Set[str] = set()
# Presence penalty related data structures
@ -168,7 +168,7 @@ class InputBatch:
device="cpu",
pin_memory=pin_memory)
self.repetition_penalties_cpu = \
self.repetition_penalties_cpu_tensor.numpy()
self.repetition_penalties_reqs: Set[str] = set()
# req_index -> (min_tokens, stop_token_ids)
@ -192,6 +192,9 @@ class InputBatch:
self.logit_bias: List[Optional[Dict[int,
float]]] = [None] * max_num_reqs
self.has_allowed_token_ids: Set[str] = set()
self.allowed_token_ids_mask: Optional[torch.Tensor] = None
self.allowed_token_ids_mask_cpu_tensor: Optional[torch.Tensor] = None
self.req_output_token_ids: List[Optional[List[int]]] = []
@ -287,6 +290,22 @@ class InputBatch:
if sampling_params.logit_bias is not None:
self.logit_bias[req_index] = sampling_params.logit_bias
if sampling_params.allowed_token_ids:
self.has_allowed_token_ids.add(req_id)
if self.allowed_token_ids_mask_cpu_tensor is None:
# Lazy allocation for this tensor, which can be large.
self.allowed_token_ids_mask = torch.zeros(self.max_num_reqs,
self.vocab_size,
dtype=torch.bool,
device=self.device)
self.allowed_token_ids_mask_cpu_tensor = torch.zeros(
self.max_num_reqs,
self.vocab_size,
dtype=torch.bool,
device="cpu")
self.allowed_token_ids_mask_cpu_tensor[req_index][
sampling_params.allowed_token_ids] = True
# Add request lora ID
if request.lora_request:
lora_id = request.lora_request.lora_int_id
@ -332,6 +351,9 @@ class InputBatch:
self.request_lora_mapping[req_index] = 0
self.logit_bias[req_index] = None
self.has_allowed_token_ids.discard(req_id)
if self.allowed_token_ids_mask_cpu_tensor is not None:
self.allowed_token_ids_mask_cpu_tensor[req_index].fill_(False)
return req_index
def condense(self, empty_req_indices: List[int]) -> None:
@ -400,6 +422,11 @@ class InputBatch:
self.logit_bias[empty_index] = self.logit_bias[last_req_index]
if self.allowed_token_ids_mask_cpu_tensor is not None:
self.allowed_token_ids_mask_cpu_tensor[
empty_index] = self.allowed_token_ids_mask_cpu_tensor[
last_req_index]
# Decrement last_req_index since it is now empty.
last_req_index -= 1
@ -442,6 +469,13 @@ class InputBatch:
else:
prompt_token_ids = None
allowed_token_ids_mask: Optional[torch.Tensor] = None
if not self.no_allowed_token_ids:
assert self.allowed_token_ids_mask is not None
copy_slice(self.allowed_token_ids_mask_cpu_tensor,
self.allowed_token_ids_mask, num_reqs)
allowed_token_ids_mask = self.allowed_token_ids_mask[:num_reqs]
return SamplingMetadata(
temperature=temperature,
all_greedy=self.all_greedy,
@ -460,6 +494,7 @@ class InputBatch:
min_tokens=self.min_tokens,
no_penalties=self.no_penalties,
logit_bias=self.logit_bias[:num_reqs],
allowed_token_ids_mask=allowed_token_ids_mask,
)
def get_sampling_metadata(
@ -550,3 +585,7 @@ class InputBatch:
@property
def no_prompt_logprob(self) -> bool:
return not self.num_prompt_logprobs
@property
def no_allowed_token_ids(self) -> bool:
return len(self.has_allowed_token_ids) == 0
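
The lazy allocation above exists because a (max_num_reqs, vocab_size) bool tensor can be sizable. A quick back-of-the-envelope check, with assumed values for max_num_reqs and vocab_size:

# Assumed, illustrative values; the real ones come from the scheduler
# and model config.
max_num_reqs = 1024
vocab_size = 152064  # a large multilingual vocabulary

# torch.bool stores one byte per element, so each copy of the mask costs:
mask_bytes = max_num_reqs * vocab_size
print(f"{mask_bytes / (1 << 20):.1f} MiB")  # ~148.5 MiB, on both CPU and GPU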