[Core] Adding token ranks along with logprobs (#3516)

Co-authored-by: Swapnil Parekh <swapnilp@ibm.com>
Author: Swapnil Parekh
Date: 2024-03-25 13:13:10 -04:00
Committed by: GitHub
parent 01bfb22b41
commit 819924e749
3 changed files with 98 additions and 12 deletions


@@ -0,0 +1,49 @@
+import pytest
+
+from vllm import SamplingParams
+
+MODELS = ["facebook/opt-125m"]
+
+
+@pytest.mark.parametrize("model", MODELS)
+@pytest.mark.parametrize("dtype", ["half"])
+def test_ranks(
+    vllm_runner,
+    model,
+    dtype,
+    example_prompts,
+):
+    max_tokens = 5
+    num_top_logprobs = 5
+    num_prompt_logprobs = 5
+    vllm_model = vllm_runner(model, dtype=dtype, max_logprobs=num_top_logprobs)
+
+    ## Test greedy logprobs ranks
+    vllm_sampling_params = SamplingParams(temperature=0.0,
+                                          top_p=1.0,
+                                          max_tokens=max_tokens,
+                                          logprobs=num_top_logprobs,
+                                          prompt_logprobs=num_prompt_logprobs)
+    vllm_results = vllm_model.generate_w_logprobs(example_prompts,
+                                                  vllm_sampling_params)
+    for result in vllm_results:
+        assert result[2] is not None
+        assert len(result[2]) == len(result[0])
+        # check whether all chosen tokens have ranks = 1
+        for token, logprobs in zip(result[0], result[2]):
+            assert token in logprobs
+            assert logprobs[token].rank == 1
+
+    ## Test non-greedy logprobs ranks
+    sampling_params = SamplingParams(temperature=1.0,
+                                     top_p=1.0,
+                                     max_tokens=max_tokens,
+                                     logprobs=num_top_logprobs,
+                                     prompt_logprobs=num_prompt_logprobs)
+    res = vllm_model.generate_w_logprobs(example_prompts, sampling_params)
+    for result in res:
+        assert result[2] is not None
+        assert len(result[2]) == len(result[0])
+        # check whether all chosen tokens have ranks
+        for token, logprobs in zip(result[0], result[2]):
+            assert logprobs[token].rank >= 1
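Outside the test harness, the new field is visible on every returned Logprob. Below is a minimal sketch of reading it through the offline LLM API; the model name and prompt are placeholders and not part of this change.

from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m", max_logprobs=5)
params = SamplingParams(temperature=0.0, max_tokens=5, logprobs=5)

for request_output in llm.generate(["Hello, my name is"], params):
    seq = request_output.outputs[0]
    # seq.logprobs is a list of {token_id: Logprob} dicts, one per generated token.
    for token_id, logprob_dict in zip(seq.token_ids, seq.logprobs):
        chosen = logprob_dict[token_id]
        # With greedy decoding the chosen token should always be rank 1.
        print(token_id, chosen.logprob, chosen.rank)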


@@ -465,6 +465,24 @@ def _sample(
     #                              sampling_tensors)
 
 
+def _get_ranks(x: torch.Tensor, indices: List[int]) -> torch.Tensor:
+    """
+    This function calculates the ranks of the chosen tokens in a logprob tensor.
+
+    Args:
+        x (torch.Tensor): 2D logprob tensor of shape (N, M)
+                          where N is the no. of tokens and M is the vocab dim.
+        indices (List[int]): List of chosen token indices.
+
+    Returns:
+        torch.Tensor: 1D tensor of shape (N,) where N is the no. of tokens.
+                      Each element in the returned tensor represents the rank
+                      of the chosen token in the input logprob tensor.
+    """
+    vals = x[range(len(x)), indices]
+    return (x > vals[:, None]).long().sum(1) + 1
+
+
 def _get_logprobs(
     logprobs: torch.Tensor,
     sampling_metadata: SamplingMetadata,
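A small worked example of the rank computation above, with made-up numbers: the rank of a chosen token is one plus the number of vocab entries with a strictly higher logprob, so the argmax token gets rank 1.

import torch

# Toy logprob tensor: 2 tokens, vocab of 4. Values are made up.
logprobs = torch.tensor([
    [-0.1, -2.0, -3.0, -0.5],
    [-1.0, -0.2, -4.0, -2.5],
])
chosen = [3, 1]  # chosen token index per row

vals = logprobs[range(len(logprobs)), chosen]
ranks = (logprobs > vals[:, None]).long().sum(1) + 1
print(ranks.tolist())
# [2, 1]: row 0's chosen token (-0.5) is beaten only by -0.1, so rank 2;
# row 1's chosen token (-0.2) is the maximum, so rank 1.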
@@ -520,6 +538,10 @@ def _get_logprobs(
     batched_logprobs_query_result = batched_logprobs_query_result.cpu()
 
+    batched_ranks_query_result = _get_ranks(
+        logprobs[batched_logprobs_query_seq_indices],
+        batched_logprobs_query_token_indices)
+
     # Gather results
     result_prompt_logprobs: List[Optional[PromptLogprobs]] = []
     result_sample_logprobs: List[SampleLogprobs] = []
@@ -540,15 +562,20 @@ def _get_logprobs(
             for token_id in prompt_tokens[1:]:
                 prompt_logprobs_dict = {
                     token_id:
-                    batched_logprobs_query_result[query_result_idx].item()
+                    (batched_logprobs_query_result[query_result_idx].item(),
+                     batched_ranks_query_result[query_result_idx].item())
                 }
                 if num_logprobs > 0:
                     prompt_logprobs_dict.update(
-                        zip(top_token_ids[sample_idx, :num_logprobs].tolist(),
-                            top_logprobs[sample_idx, :num_logprobs].tolist()))
+                        zip(
+                            top_token_ids[sample_idx, :num_logprobs].tolist(),
+                            zip(
+                                top_logprobs[
+                                    sample_idx, :num_logprobs].tolist(),
+                                range(1, num_logprobs + 1))))
                 group_prompt_logprobs.append({
-                    token_id: Logprob(logprob)
-                    for token_id, logprob in prompt_logprobs_dict.items()
+                    token_id: Logprob(*logprob_rank)
+                    for token_id, logprob_rank in prompt_logprobs_dict.items()
                 })
                 sample_idx += 1
                 query_result_idx += 1
@@ -564,7 +591,8 @@ def _get_logprobs(
         for next_token_id, parent_id in zip(next_token_ids, parent_ids):
             sample_logprobs_dict = {
                 next_token_id:
-                batched_logprobs_query_result[query_result_idx].item()
+                (batched_logprobs_query_result[query_result_idx].item(),
+                 batched_ranks_query_result[query_result_idx].item())
             }
             query_result_idx += 1
             if num_logprobs > 0:
@@ -572,11 +600,13 @@ def _get_logprobs(
                     zip(
                         top_token_ids[sample_idx +
                                       parent_id, :num_logprobs].tolist(),
-                        top_logprobs[sample_idx +
-                                     parent_id, :num_logprobs].tolist()))
+                        zip(
+                            top_logprobs[sample_idx +
+                                         parent_id, :num_logprobs].tolist(),
+                            range(1, num_logprobs + 1))))
             group_sample_logprobs.append({
-                token_id: Logprob(logprob)
-                for token_id, logprob in sample_logprobs_dict.items()
+                token_id: Logprob(*logprob_rank)
+                for token_id, logprob_rank in sample_logprobs_dict.items()
             })
         result_sample_logprobs.append(group_sample_logprobs)
         sample_idx += len(seq_ids)
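The nested zip above is easier to follow with concrete values. A sketch with made-up numbers of how one per-token dict is now assembled: the chosen token carries the rank computed by _get_ranks, while the already-sorted top-k entries get ranks 1..k by position.

from vllm.sequence import Logprob

# Chosen (sampled) token: (logprob, computed rank). Values are made up.
sample_logprobs_dict = {1234: (-2.3, 7)}

# Top-3 candidates, sorted by descending logprob, so their ranks are simply 1..3.
top_ids = [50, 61, 72]
top_lps = [-0.4, -1.1, -1.9]
sample_logprobs_dict.update(zip(top_ids, zip(top_lps, range(1, 4))))
# {1234: (-2.3, 7), 50: (-0.4, 1), 61: (-1.1, 2), 72: (-1.9, 3)}

result = {
    token_id: Logprob(*logprob_rank)
    for token_id, logprob_rank in sample_logprobs_dict.items()
}
print(result[50].rank)  # 1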


@@ -16,8 +16,15 @@ if TYPE_CHECKING:
 @dataclass
 class Logprob:
-    """Infos for supporting OpenAI compatible logprobs."""
+    """Infos for supporting OpenAI compatible logprobs and token ranks.
+
+    Attributes:
+        logprob: The logprob of the chosen token
+        rank: The vocab rank of the chosen token (>=1)
+        decoded_token: The decoded chosen token
+    """
     logprob: float
+    rank: Optional[int] = None
     decoded_token: Optional[str] = None
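Because rank and decoded_token default to None, the widened dataclass stays backward compatible: callers that only pass a logprob keep working, and the sampler can unpack its (logprob, rank) pairs positionally. A small sketch:

from vllm.sequence import Logprob

old_style = Logprob(-0.25)        # rank and decoded_token default to None
with_rank = Logprob(*(-0.25, 3))  # unpacking a (logprob, rank) tuple, as in the sampler
print(old_style.rank, with_rank.rank)  # None 3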
@@ -66,7 +73,7 @@ class SequenceStatus(enum.Enum):
 class RequestMetrics:
     """Metrics associated with a request.
 
-    Args:
+    Attributes:
         arrival_time: The time when the request arrived.
         first_scheduled_time: The time when the request was first scheduled.
         first_token_time: The time when the first token was generated.