Fix typical acceptance sampler with correct recovered token ids (#8562)

This commit is contained in:
jiqing-feng
2024-09-24 03:32:27 +08:00
committed by GitHub
parent b05f5c9238
commit 5f7bb58427
2 changed files with 17 additions and 28 deletions

View File

@ -365,7 +365,7 @@ def test_accept_tokens_partially(seed: int, device: str):
# Next only keep the first 2 draft tokens same as the zero temperature
# tokens. For the remaining 3 choose some other tokens. In the
# response we will expect the first 2 tokens to be the same as the
# draft tokens and the rest as -1
# draft tokens and the recovered token and rest as -1
draft_token_ids_to_replace = get_draft_token_ids(
batch_size, k, vocab_size, zero_temperature_token_ids)
draft_token_ids = torch.cat(
@ -378,6 +378,8 @@ def test_accept_tokens_partially(seed: int, device: str):
assert output_token_ids.shape[0] == batch_size
assert output_token_ids.shape[1] == (k + 1)
assert torch.all(output_token_ids[:, :2] == draft_token_ids[:, :2])
assert torch.all(
output_token_ids[:, 2] == target_with_bonus_probs.argmax(-1)[:, 2])
assert torch.all(output_token_ids[:, -3:] == -1)
@ -443,14 +445,14 @@ def test_accept_tokens_set_non_default_posteriors(seed: int, device: str):
@pytest.mark.parametrize("seed", list(range(10)))
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_replacement_token_ids(seed: int, device: str):
def test_get_recovered_token_ids(seed: int, device: str):
"""
Test the TypicalAcceptanceSampler's method for generating
replacement token IDs.
This test verifies that the `_replacement_token_ids` method of the
This test verifies that the `_get_recovered_token_ids` method of the
TypicalAcceptanceSampler correctly identifies the token IDs to be used
as replacements based on the target probability distribution.
as recovered token IDs based on the target probability distribution.
Specifically, it ensures that the method correctly identifies the
tokens with the highest probability for each sequence in the batch.
"""
@ -462,10 +464,7 @@ def test_replacement_token_ids(seed: int, device: str):
typical_acceptance_sampler = get_acceptance_sampler(strict_mode=True)
typical_acceptance_sampler.init_gpu_tensors(device=device)
target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
expected_replacement_tokens = -torch.ones(
(batch_size, k), dtype=torch.long)
expected_replacement_tokens[:, 0] = torch.argmax(target_probs[:, 0, :],
dim=1)
expected_replacement_tokens = torch.argmax(target_probs, dim=-1)
actual_replacement_tokens = (
typical_acceptance_sampler._replacement_token_ids(target_probs))
typical_acceptance_sampler._get_recovered_token_ids(target_probs))
assert torch.all(expected_replacement_tokens == actual_replacement_tokens)

View File

@ -80,7 +80,7 @@ class TypicalAcceptanceSampler(SpecDecodeDeterministicBaseSampler):
target_probs = target_with_bonus_probs[:, :-1]
accepted = self._evaluate_accepted_tokens(target_probs,
draft_token_ids)
recovered_token_ids = self._replacement_token_ids(target_probs)
recovered_token_ids = self._get_recovered_token_ids(target_probs)
output_token_ids = self._create_output(accepted, recovered_token_ids,
draft_token_ids,
bonus_token_ids)
@ -148,16 +148,10 @@ class TypicalAcceptanceSampler(SpecDecodeDeterministicBaseSampler):
accepted_mask = candidates_prob > threshold
return accepted_mask
def _replacement_token_ids(self, target_probs):
def _get_recovered_token_ids(self, target_probs):
"""
Generate one replacement token ID for each sequence based on target
probabilities. The replacement token is used as the fallback option
if typical acceptance sampling does not accept any draft tokens for
that particular sequence.
This method computes the token IDs to be replaced by selecting the
token with the highest probability for each sequence in the first
position. The rest of the output is filled with -1.
The recovered token ids will fill the first unmatched token
by the target token.
Parameters
----------
@ -168,13 +162,9 @@ class TypicalAcceptanceSampler(SpecDecodeDeterministicBaseSampler):
Returns
-------
torch.Tensor
A tensor of shape (batch_size, k) with the replacement
token IDs. Only the first column is set, and the rest of the
columns are filled with -1.
A tensor of shape (batch_size, k) with the recovered token
ids which are selected from target probs.
"""
max_indices = torch.argmax(target_probs[:, 0, :], dim=1)
output = -torch.ones((target_probs.shape[0], target_probs.shape[1]),
dtype=self.token_id_dtype,
device=target_probs.device)
output[:, 0] = max_indices
return output
max_indices = torch.argmax(target_probs, dim=-1)
return max_indices