mirror of https://github.com/vllm-project/vllm.git
synced 2025-10-20 14:53:52 +08:00
Fix typical acceptance sampler with correct recovered token ids (#8562)
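This commit renames `_replacement_token_ids` to `_get_recovered_token_ids` and changes it to return the target model's argmax token for every speculative position, instead of only for the first position (with -1 everywhere else). A minimal standalone sketch of the behavioral change (plain PyTorch, not the vllm code; shapes taken from the diff below):

```python
import torch

batch_size, k, vocab_size = 2, 5, 10
target_probs = torch.rand(batch_size, k, vocab_size)

# Before: only position 0 had a recovered token; a rejection at any
# later position had no fallback and was emitted as -1.
old = -torch.ones((batch_size, k), dtype=torch.long)
old[:, 0] = torch.argmax(target_probs[:, 0, :], dim=1)

# After: every position carries the target distribution's argmax, so
# whichever draft token is rejected first can be replaced correctly.
new = torch.argmax(target_probs, dim=-1)

assert new.shape == (batch_size, k)
assert torch.all(old[:, 0] == new[:, 0])
```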
@@ -365,7 +365,7 @@ def test_accept_tokens_partially(seed: int, device: str):
     # Next only keep the first 2 draft tokens same as the zero temperature
     # tokens. For the remaining 3 choose some other tokens. In the
     # response we will expect the first 2 tokens to be the same as the
-    # draft tokens and the rest as -1
+    # draft tokens, the third to be the recovered token, and the rest -1
     draft_token_ids_to_replace = get_draft_token_ids(
         batch_size, k, vocab_size, zero_temperature_token_ids)
     draft_token_ids = torch.cat(
@@ -378,6 +378,8 @@ def test_accept_tokens_partially(seed: int, device: str):
     assert output_token_ids.shape[0] == batch_size
     assert output_token_ids.shape[1] == (k + 1)
     assert torch.all(output_token_ids[:, :2] == draft_token_ids[:, :2])
+    assert torch.all(
+        output_token_ids[:, 2] == target_with_bonus_probs.argmax(-1)[:, 2])
     assert torch.all(output_token_ids[:, -3:] == -1)


@@ -443,14 +445,14 @@ def test_accept_tokens_set_non_default_posteriors(seed: int, device: str):
 @pytest.mark.parametrize("seed", list(range(10)))
 @pytest.mark.parametrize("device", CUDA_DEVICES)
 @torch.inference_mode()
-def test_replacement_token_ids(seed: int, device: str):
+def test_get_recovered_token_ids(seed: int, device: str):
     """
     Test the TypicalAcceptanceSampler's method for generating
     replacement token IDs.

-    This test verifies that the `_replacement_token_ids` method of the
+    This test verifies that the `_get_recovered_token_ids` method of the
     TypicalAcceptanceSampler correctly identifies the token IDs to be used
-    as replacements based on the target probability distribution.
+    as recovered tokens based on the target probability distribution.
     Specifically, it ensures that the method correctly identifies the
     tokens with the highest probability for each sequence in the batch.
     """
@@ -462,10 +464,7 @@ def test_replacement_token_ids(seed: int, device: str):
     typical_acceptance_sampler = get_acceptance_sampler(strict_mode=True)
     typical_acceptance_sampler.init_gpu_tensors(device=device)
     target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
-    expected_replacement_tokens = -torch.ones(
-        (batch_size, k), dtype=torch.long)
-    expected_replacement_tokens[:, 0] = torch.argmax(target_probs[:, 0, :],
-                                                     dim=1)
+    expected_replacement_tokens = torch.argmax(target_probs, dim=-1)
     actual_replacement_tokens = (
-        typical_acceptance_sampler._replacement_token_ids(target_probs))
+        typical_acceptance_sampler._get_recovered_token_ids(target_probs))
     assert torch.all(expected_replacement_tokens == actual_replacement_tokens)
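In `test_accept_tokens_partially`, the first two draft tokens match the zero-temperature (argmax) target tokens and the remaining three do not, so position 2 is the first rejected position. The new assertion pins down that this slot now holds the recovered token, i.e. the target argmax at position 2, while the trailing slots stay -1. A sketch of the expected output layout (the token IDs here are hypothetical, chosen only for illustration):

```python
import torch

# k = 5 draft tokens, so output_token_ids has k + 1 = 6 slots per sequence.
draft_row = torch.tensor([7, 3, 9, 4, 1])  # first two match the target argmax
recovered_at_2 = 5                         # hypothetical target argmax at pos 2
expected_row = torch.tensor([7, 3, recovered_at_2, -1, -1, -1])
assert torch.equal(expected_row[:2], draft_row[:2])  # accepted drafts kept
assert expected_row[2] == recovered_at_2             # first rejection recovered
assert torch.all(expected_row[-3:] == -1)            # nothing emitted afterwards
```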
@@ -80,7 +80,7 @@ class TypicalAcceptanceSampler(SpecDecodeDeterministicBaseSampler):
         target_probs = target_with_bonus_probs[:, :-1]
         accepted = self._evaluate_accepted_tokens(target_probs,
                                                   draft_token_ids)
-        recovered_token_ids = self._replacement_token_ids(target_probs)
+        recovered_token_ids = self._get_recovered_token_ids(target_probs)
         output_token_ids = self._create_output(accepted, recovered_token_ids,
                                                draft_token_ids,
                                                bonus_token_ids)
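The recovered token IDs are consumed by `_create_output` on the base class, which (per the docstring below) fills the first rejected draft position with the target's token. A hedged sketch of that combination step, assuming the usual spec-decode convention that a bonus token is emitted only when every draft token is accepted (this is not vllm's `_create_output`, just an illustration of the contract):

```python
import torch

def create_output_sketch(accepted: torch.Tensor, recovered: torch.Tensor,
                         draft: torch.Tensor, bonus: torch.Tensor) -> torch.Tensor:
    batch_size, k = draft.shape
    out = torch.full((batch_size, k + 1), -1, dtype=torch.long)
    for i in range(batch_size):
        # Count leading accepted draft tokens: cumprod turns the boolean
        # mask into 1s up to the first rejection and 0s afterwards.
        n = int(torch.cumprod(accepted[i].long(), dim=0).sum())
        out[i, :n] = draft[i, :n]
        if n < k:
            # First rejected slot gets the recovered (target argmax) token.
            out[i, n] = recovered[i, n]
        else:
            # All drafts accepted: the bonus token is emitted last.
            out[i, k] = bonus[i, 0]
    return out
```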
@@ -148,16 +148,10 @@ class TypicalAcceptanceSampler(SpecDecodeDeterministicBaseSampler):
         accepted_mask = candidates_prob > threshold
         return accepted_mask

-    def _replacement_token_ids(self, target_probs):
+    def _get_recovered_token_ids(self, target_probs):
         """
-        Generate one replacement token ID for each sequence based on target
-        probabilities. The replacement token is used as the fallback option
-        if typical acceptance sampling does not accept any draft tokens for
-        that particular sequence.
-
-        This method computes the token IDs to be replaced by selecting the
-        token with the highest probability for each sequence in the first
-        position. The rest of the output is filled with -1.
+        The recovered token IDs replace the first rejected draft token
+        with the target model's highest-probability token.

         Parameters
         ----------
@@ -168,13 +162,9 @@ class TypicalAcceptanceSampler(SpecDecodeDeterministicBaseSampler):
         Returns
         -------
         torch.Tensor
-            A tensor of shape (batch_size, k) with the replacement
-            token IDs. Only the first column is set, and the rest of the
-            columns are filled with -1.
+            A tensor of shape (batch_size, k) with the recovered token
+            IDs, selected as the argmax of the target probabilities.
         """
-        max_indices = torch.argmax(target_probs[:, 0, :], dim=1)
-        output = -torch.ones((target_probs.shape[0], target_probs.shape[1]),
-                             dtype=self.token_id_dtype,
-                             device=target_probs.device)
-        output[:, 0] = max_indices
-        return output
+        max_indices = torch.argmax(target_probs, dim=-1)
+
+        return max_indices
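A quick numeric check of the new return value: `torch.argmax(target_probs, dim=-1)` reduces over the vocab dimension, yielding one token ID per (sequence, position) pair (toy numbers for illustration):

```python
import torch

# batch_size=1, k=2, vocab_size=3
target_probs = torch.tensor([[[0.1, 0.7, 0.2],    # pos 0 -> argmax is token 1
                              [0.6, 0.3, 0.1]]])  # pos 1 -> argmax is token 0
recovered = torch.argmax(target_probs, dim=-1)
assert recovered.tolist() == [[1, 0]]             # shape (batch_size, k)
```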