Signed-off-by: Woosuk Kwon <woosuk@thinkingmachines.ai>
Author: Woosuk Kwon
Date: 2025-09-24 15:29:27 +00:00
Parent: ad2cf805ad
Commit: 866eef50ca


@@ -231,6 +231,31 @@ def _topk_log_softmax_kernel(
     tl.store(output_ptr + req_idx * topk + k_offset, o, mask=k_mask)
 
 
+def compute_topk_logprobs(
+    logits: torch.Tensor,
+    topk_ids: torch.Tensor,
+) -> torch.Tensor:
+    batch_size, vocab_size = logits.shape
+    topk = topk_ids.shape[1]
+    output = torch.empty(
+        batch_size,
+        topk,
+        dtype=torch.float32,
+        device=logits.device,
+    )
+    _topk_log_softmax_kernel[(batch_size, )](
+        output,
+        logits,
+        logits.stride(0),
+        topk_ids,
+        topk,
+        vocab_size,
+        BLOCK_SIZE=1024,
+        PADDED_TOPK=triton.next_power_of_2(topk),
+    )
+    return output
+
+
 @triton.jit
 def _ranks_kernel(
     output_ptr,
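
For context (not part of the commit): the new compute_topk_logprobs helper has the same semantics as taking a full log-softmax over the vocabulary and gathering the requested token columns. A minimal eager-PyTorch sketch of that reference computation, with compute_topk_logprobs_ref as a hypothetical name for illustration:

    import torch

    def compute_topk_logprobs_ref(logits: torch.Tensor,
                                  topk_ids: torch.Tensor) -> torch.Tensor:
        # Materializes the full [batch_size, vocab_size] log-softmax, then
        # gathers the requested token columns. The Triton kernel in the hunk
        # above produces the same [batch_size, topk] float32 result without
        # the full intermediate tensor.
        logprobs = torch.log_softmax(logits.float(), dim=-1)
        return logprobs.gather(dim=1, index=topk_ids.long())
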
@@ -273,21 +298,9 @@ def compute_logprobs(
     # NOTE(woosuk): Here, to save GPU memory, we do not materialize the full
     # logprobs tensor. Instead, we only compute and return the logprobs of
     # the topk + 1 tokens.
-    logprobs = torch.empty(
-        batch_size,
-        num_logprobs + 1,
-        dtype=torch.float32,
-        device=logits.device,
-    )
-    _topk_log_softmax_kernel[(batch_size, )](
-        logprobs,
+    logprobs = compute_topk_logprobs(
         logits,
-        logits.stride(0),
         logprob_token_ids,
-        num_logprobs + 1,
-        vocab_size,
-        BLOCK_SIZE=1024,
-        PADDED_TOPK=triton.next_power_of_2(num_logprobs + 1),
-    )
     )
     token_ranks = torch.empty(
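
To put the NOTE's memory argument in numbers (illustrative sizes, not from the commit): with batch_size = 1024, vocab_size = 131072, and num_logprobs = 5, a fully materialized float32 logprobs tensor would occupy 1024 * 131072 * 4 B = 512 MiB, whereas the topk + 1 = 6 values per request returned here occupy only 1024 * 6 * 4 B = 24 KiB.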