fix long dtype in topk sampling (#15049)

Author: Chujie Zheng
Date: 2025-03-19 06:57:31 +08:00
Committed by: GitHub
Parent: 72a8639b68
Commit: 027827cc1d

@@ -151,7 +151,7 @@ class Sampler(nn.Module):
     dim=-1)
 # Get with the logprob of the prompt or sampled token.
-token_ids = token_ids.unsqueeze(-1)
+token_ids = token_ids.unsqueeze(-1).to(torch.long)
 token_logprobs = logprobs.gather(-1, token_ids)
 # Compute the ranks of the actual token.
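For context, `torch.gather` requires its index tensor to have dtype `torch.int64` (long); if the sampled `token_ids` come back as a narrower integer dtype, the gather raises a runtime error, which is what the added `.to(torch.long)` cast guards against. A minimal sketch (not part of the commit; tensor shapes and names are illustrative):

```python
import torch

# Assume logprobs has shape (num_tokens, vocab_size) and token_ids holds one
# sampled token id per row, here deliberately created as int32.
logprobs = torch.randn(4, 8).log_softmax(dim=-1)
token_ids = torch.randint(0, 8, (4,), dtype=torch.int32)

# Without the cast, gather() fails with
# "RuntimeError: gather(): Expected dtype int64 for index".
token_ids = token_ids.unsqueeze(-1).to(torch.long)
token_logprobs = logprobs.gather(-1, token_ids)
print(token_logprobs.shape)  # torch.Size([4, 1])
```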