Mirror of https://github.com/vllm-project/vllm.git (synced 2025-10-20 14:53:52 +08:00)
[Core] Adding token ranks along with logprobs (#3516)
Co-authored-by: Swapnil Parekh <swapnilp@ibm.com>
tests/samplers/test_ranks.py (new file, 49 lines)
@@ -0,0 +1,49 @@
import pytest

from vllm import SamplingParams

MODELS = ["facebook/opt-125m"]


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"])
def test_ranks(
    vllm_runner,
    model,
    dtype,
    example_prompts,
):
    max_tokens = 5
    num_top_logprobs = 5
    num_prompt_logprobs = 5

    vllm_model = vllm_runner(model, dtype=dtype, max_logprobs=num_top_logprobs)

    ## Test greedy logprobs ranks
    vllm_sampling_params = SamplingParams(temperature=0.0,
                                          top_p=1.0,
                                          max_tokens=max_tokens,
                                          logprobs=num_top_logprobs,
                                          prompt_logprobs=num_prompt_logprobs)
    vllm_results = vllm_model.generate_w_logprobs(example_prompts,
                                                  vllm_sampling_params)
    for result in vllm_results:
        assert result[2] is not None
        assert len(result[2]) == len(result[0])
        # check whether all chosen tokens have ranks = 1
        for token, logprobs in zip(result[0], result[2]):
            assert token in logprobs
            assert logprobs[token].rank == 1

    ## Test non-greedy logprobs ranks
    sampling_params = SamplingParams(temperature=1.0,
                                     top_p=1.0,
                                     max_tokens=max_tokens,
                                     logprobs=num_top_logprobs,
                                     prompt_logprobs=num_prompt_logprobs)
    res = vllm_model.generate_w_logprobs(example_prompts, sampling_params)
    for result in res:
        assert result[2] is not None
        assert len(result[2]) == len(result[0])
        # check whether all chosen tokens have ranks
        for token, logprobs in zip(result[0], result[2]):
            assert logprobs[token].rank >= 1
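For context, here is a minimal sketch of how the new rank field would surface through the public API once this change lands. The `LLM`/`SamplingParams` entry points are standard vLLM; the output layout (a per-token dict of token id to `Logprob`) follows the `Logprob` dataclass extended later in this diff. Treat it as an illustration rather than part of the commit.

    from vllm import LLM, SamplingParams

    # Sketch (not part of the commit): inspect token ranks in generation output.
    llm = LLM(model="facebook/opt-125m", max_logprobs=5)
    params = SamplingParams(temperature=0.0, max_tokens=5, logprobs=5)
    outputs = llm.generate(["Hello, my name is"], params)

    completion = outputs[0].outputs[0]
    for token_id, logprob_dict in zip(completion.token_ids, completion.logprobs):
        chosen = logprob_dict[token_id]
        # Greedy decoding always picks the highest-probability token,
        # so its rank should come back as 1.
        print(token_id, chosen.logprob, chosen.rank)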
@@ -465,6 +465,24 @@ def _sample(
     # sampling_tensors)


+def _get_ranks(x: torch.Tensor, indices: List[int]) -> torch.Tensor:
+    """
+    This function calculates the ranks of the chosen tokens in a logprob tensor.
+
+    Args:
+        x (torch.Tensor): 2D logprob tensor of shape (N, M)
+                        where N is the no. of tokens and M is the vocab dim.
+        indices (List[int]): List of chosen token indices.
+
+    Returns:
+        torch.Tensor: 1D tensor of shape (N,) where N is the no. of tokens.
+                      Each element in the returned tensor represents the rank
+                      of the chosen token in the input logprob tensor.
+    """
+    vals = x[range(len(x)), indices]
+    return (x > vals[:, None]).long().sum(1) + 1
+
+
 def _get_logprobs(
     logprobs: torch.Tensor,
     sampling_metadata: SamplingMetadata,
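This and the following hunks appear to be from vLLM's sampler module. The rank computed by `_get_ranks` counts how many vocabulary entries score strictly higher than the chosen token, plus one, so the argmax gets rank 1. A standalone sketch of the same computation with toy tensors (not part of the diff):

    import torch

    # Two "token positions" over a 3-word vocabulary.
    logprobs = torch.log_softmax(torch.tensor([[2.0, 0.5, 1.0],
                                               [0.1, 0.2, 3.0]]), dim=-1)
    chosen = [1, 2]  # chosen token id per position
    vals = logprobs[range(len(logprobs)), chosen]
    ranks = (logprobs > vals[:, None]).long().sum(1) + 1
    # tensor([3, 1]): token 1 is 3rd-best at position 0, token 2 is best at position 1
    print(ranks)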
@@ -520,6 +538,10 @@ def _get_logprobs(

     batched_logprobs_query_result = batched_logprobs_query_result.cpu()

+    batched_ranks_query_result = _get_ranks(
+        logprobs[batched_logprobs_query_seq_indices],
+        batched_logprobs_query_token_indices)
+
     # Gather results
     result_prompt_logprobs: List[Optional[PromptLogprobs]] = []
     result_sample_logprobs: List[SampleLogprobs] = []
@@ -540,15 +562,20 @@ def _get_logprobs(
             for token_id in prompt_tokens[1:]:
                 prompt_logprobs_dict = {
                     token_id:
-                    batched_logprobs_query_result[query_result_idx].item()
+                    (batched_logprobs_query_result[query_result_idx].item(),
+                     batched_ranks_query_result[query_result_idx].item())
                 }
                 if num_logprobs > 0:
                     prompt_logprobs_dict.update(
-                        zip(top_token_ids[sample_idx, :num_logprobs].tolist(),
-                            top_logprobs[sample_idx, :num_logprobs].tolist()))
+                        zip(
+                            top_token_ids[sample_idx, :num_logprobs].tolist(),
+                            zip(
+                                top_logprobs[
+                                    sample_idx, :num_logprobs].tolist(),
+                                range(1, num_logprobs + 1))))
                 group_prompt_logprobs.append({
-                    token_id: Logprob(logprob)
-                    for token_id, logprob in prompt_logprobs_dict.items()
+                    token_id: Logprob(*logprob_rank)
+                    for token_id, logprob_rank in prompt_logprobs_dict.items()
                 })
                 sample_idx += 1
                 query_result_idx += 1
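The nested `zip` above pairs each top-k token id with a `(logprob, rank)` tuple, where the rank is simply the 1-based position in the already-sorted top-k list; `Logprob(*logprob_rank)` then unpacks that tuple into the dataclass. A small standalone illustration with made-up values:

    # Made-up top-k results for one token position (not from the diff).
    top_token_ids = [42, 7, 99]
    top_logprobs = [-0.1, -1.3, -2.8]

    pairs = dict(
        zip(top_token_ids,
            zip(top_logprobs, range(1, len(top_logprobs) + 1))))
    # pairs == {42: (-0.1, 1), 7: (-1.3, 2), 99: (-2.8, 3)}
    # so Logprob(*pairs[7]) is equivalent to Logprob(logprob=-1.3, rank=2)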
@@ -564,7 +591,8 @@ def _get_logprobs(
         for next_token_id, parent_id in zip(next_token_ids, parent_ids):
             sample_logprobs_dict = {
                 next_token_id:
-                batched_logprobs_query_result[query_result_idx].item()
+                (batched_logprobs_query_result[query_result_idx].item(),
+                 batched_ranks_query_result[query_result_idx].item())
             }
             query_result_idx += 1
             if num_logprobs > 0:
@@ -572,11 +600,13 @@ def _get_logprobs(
                     zip(
                         top_token_ids[sample_idx +
                                       parent_id, :num_logprobs].tolist(),
-                        top_logprobs[sample_idx +
-                                     parent_id, :num_logprobs].tolist()))
+                        zip(
+                            top_logprobs[sample_idx +
+                                         parent_id, :num_logprobs].tolist(),
+                            range(1, num_logprobs + 1))))
             group_sample_logprobs.append({
-                token_id: Logprob(logprob)
-                for token_id, logprob in sample_logprobs_dict.items()
+                token_id: Logprob(*logprob_rank)
+                for token_id, logprob_rank in sample_logprobs_dict.items()
             })
             result_sample_logprobs.append(group_sample_logprobs)
             sample_idx += len(seq_ids)
@@ -16,8 +16,15 @@ if TYPE_CHECKING:

 @dataclass
 class Logprob:
-    """Infos for supporting OpenAI compatible logprobs."""
+    """Infos for supporting OpenAI compatible logprobs and token ranks.
+
+    Attributes:
+        logprob: The logprob of chosen token
+        rank: The vocab rank of chosen token (>=1)
+        decoded_token: The decoded chosen token index
+    """
     logprob: float
+    rank: Optional[int] = None
     decoded_token: Optional[str] = None

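This hunk appears to be from vllm/sequence.py. With the new field, a `Logprob` can be built either positionally (as the sampler now does via `Logprob(*logprob_rank)`) or by keyword, and `rank` defaults to `None` when omitted. A quick sketch:

    # Both forms produce the same record; rank defaults to None when omitted.
    a = Logprob(-1.3, 2)
    b = Logprob(logprob=-1.3, rank=2)
    c = Logprob(-1.3)  # rank is None, decoded_token is None
    assert a == b and c.rank is None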
@@ -66,7 +73,7 @@ class SequenceStatus(enum.Enum):
 class RequestMetrics:
     """Metrics associated with a request.

-    Args:
+    Attributes:
         arrival_time: The time when the request arrived.
         first_scheduled_time: The time when the request was first scheduled.
         first_token_time: The time when the first token was generated.