Mirror of https://github.com/vllm-project/vllm.git (synced 2025-10-20 14:53:52 +08:00)
@@ -231,6 +231,31 @@ def _topk_log_softmax_kernel(
     tl.store(output_ptr + req_idx * topk + k_offset, o, mask=k_mask)
 
 
+def compute_topk_logprobs(
+    logits: torch.Tensor,
+    topk_ids: torch.Tensor,
+) -> torch.Tensor:
+    batch_size, vocab_size = logits.shape
+    topk = topk_ids.shape[1]
+    output = torch.empty(
+        batch_size,
+        topk,
+        dtype=torch.float32,
+        device=logits.device,
+    )
+    _topk_log_softmax_kernel[(batch_size, )](
+        output,
+        logits,
+        logits.stride(0),
+        topk_ids,
+        topk,
+        vocab_size,
+        BLOCK_SIZE=1024,
+        PADDED_TOPK=triton.next_power_of_2(topk),
+    )
+    return output
+
+
 @triton.jit
 def _ranks_kernel(
     output_ptr,
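The new compute_topk_logprobs helper allocates a [batch_size, topk] float32 output and launches _topk_log_softmax_kernel with one program per batch row, padding topk to the next power of two for the kernel. For readers without Triton, here is a minimal pure-PyTorch sketch of the same computation; the name compute_topk_logprobs_ref and the assumption that topk_ids holds int64 vocabulary indices are mine, not from the diff. Unlike the kernel, it materializes the full log-softmax, so it is only useful as a correctness reference on small inputs:

import torch

def compute_topk_logprobs_ref(
    logits: torch.Tensor,    # [batch_size, vocab_size]
    topk_ids: torch.Tensor,  # [batch_size, topk], int64 vocab indices
) -> torch.Tensor:
    # Full-vocab log-softmax in float32, then gather the topk entries.
    # This materializes a [batch_size, vocab_size] tensor, which the
    # Triton kernel deliberately avoids.
    logprobs = torch.log_softmax(logits.float(), dim=-1)
    return logprobs.gather(dim=1, index=topk_ids)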
@@ -273,21 +298,9 @@ def compute_logprobs(
     # NOTE(woosuk): Here, to save GPU memory, we do not materialize the full
     # logprobs tensor. Instead, we only compute and return the logprobs of
     # the topk + 1 tokens.
-    logprobs = torch.empty(
-        batch_size,
-        num_logprobs + 1,
-        dtype=torch.float32,
-        device=logits.device,
-    )
-    _topk_log_softmax_kernel[(batch_size, )](
-        logprobs,
+    logprobs = compute_topk_logprobs(
         logits,
-        logits.stride(0),
         logprob_token_ids,
-        num_logprobs + 1,
-        vocab_size,
-        BLOCK_SIZE=1024,
-        PADDED_TOPK=triton.next_power_of_2(num_logprobs + 1),
     )
 
     token_ranks = torch.empty(
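With the helper in place, compute_logprobs shrinks to a single call, and the memory argument in the NOTE becomes concrete: for, say, batch_size = 1024, vocab_size = 128,000, and num_logprobs = 20, a full float32 logprobs tensor would take 1024 × 128,000 × 4 bytes ≈ 524 MB, while the topk + 1 slice is only 1024 × 21 × 4 bytes ≈ 86 KB. A quick equivalence check against the reference sketched above (illustrative sizes and tolerances; assumes a CUDA device, since the kernel runs on GPU):

import torch

batch_size, vocab_size, num_logprobs = 8, 32000, 5
logits = torch.randn(batch_size, vocab_size, device="cuda")
logprob_token_ids = torch.randint(
    0, vocab_size, (batch_size, num_logprobs + 1),
    device="cuda", dtype=torch.int64,
)

# Triton path from this diff vs. the pure-PyTorch reference above.
out = compute_topk_logprobs(logits, logprob_token_ids)
ref = compute_topk_logprobs_ref(logits, logprob_token_ids)
torch.testing.assert_close(out, ref, rtol=1e-4, atol=1e-4)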