[V1][TPU] TPU-optimized top-p implementation (avoids scattering). (#15736)

Signed-off-by: Hyesoo Yang <hyeygit@gmail.com>
Co-authored-by: root <root@t1v-n-822696b7-w-0.us-central2-b.c.tpu-prod-env-large-adhoc.internal>
Author: Hyesoo Yang
Date: 2025-04-02 17:18:08 -07:00
Committed by: GitHub
Parent: 55acf86bf8
Commit: 1b84eff03a
3 changed files with 174 additions and 15 deletions


@@ -36,7 +36,9 @@ docker run --privileged --net host --shm-size=16G -it \
     && echo TEST_6 \
     && pytest -s -v /workspace/vllm/tests/v1/tpu/worker/test_tpu_model_runner.py \
     && echo TEST_7 \
-    && pytest -s -v /workspace/vllm/tests/v1/tpu/test_sampler.py" \
+    && pytest -s -v /workspace/vllm/tests/v1/tpu/test_sampler.py \
+    && echo TEST_8 \
+    && pytest -s -v /workspace/vllm/tests/v1/tpu/test_topk_topp_sampler.py" \
 # TODO: This test fails because it uses RANDOM_SEED sampling

tests/v1/tpu/test_topk_topp_sampler.py (new file)

@@ -0,0 +1,132 @@
# SPDX-License-Identifier: Apache-2.0
import math

import pytest
import torch

from vllm.platforms import current_platform
from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p_tpu

if not current_platform.is_tpu():
    pytest.skip("This test needs a TPU.", allow_module_level=True)

import torch_xla.core.xla_model as xm

BATCH_SIZE = 1024
VOCAB_SIZE = 128 * 1024
TOLERANCE = 1e-6


def test_topp_result_sums_past_p():
    with torch.device(xm.xla_device()):
        xm.set_rng_state(seed=33)

        logits = torch.rand((BATCH_SIZE, VOCAB_SIZE))
        probs = logits.softmax(dim=-1)

        # Random top-p values between 0 and 1.
        p = torch.rand((BATCH_SIZE, ))

        # Set p=1 for ~50% of requests in the batch (top-p disabled).
        p.masked_fill_(torch.randint(0, 2, (BATCH_SIZE, ), dtype=bool), 1)

        no_op_k = torch.tensor([VOCAB_SIZE])
        logits_masked = apply_top_k_top_p_tpu(logits=logits.clone(),
                                              k=no_op_k,
                                              p=p)

        # Verify that the surviving probability mass sums to at least p.
        probs.masked_fill_(logits_masked.isinf(), 0)
        masked_prob_sum = probs.sum(dim=-1)
        xm.mark_step()

        # Perform the assertion on CPU.
        assert torch.all(torch.ge(masked_prob_sum.cpu() + TOLERANCE, p.cpu()))


def test_topp_basic():
    with torch.device(xm.xla_device()):
        logits = torch.tensor([[math.log(0.2), math.log(0.3), math.log(0.5)],
                               [math.log(0.5), math.log(0.1), math.log(0.4)]])
        result = apply_top_k_top_p_tpu(logits=logits.clone(),
                                       k=torch.tensor([3, 3]),
                                       p=torch.tensor([0.79, 0.79]))
        xm.mark_step()

        # Expect the smallest elements to be dropped.
        expected_result = logits.clone().cpu()
        expected_result[0, 0] = float("-inf")
        expected_result[1, 1] = float("-inf")
        assert torch.allclose(expected_result, result.cpu())
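        # Worked by hand: row 0 probabilities are (0.2, 0.3, 0.5); with p=0.79
        # the smallest set whose mass reaches 0.79 is {0.5, 0.3} (sum 0.8), so
        # only the 0.2 entry is masked. Row 1 keeps {0.5, 0.4} (sum 0.9) and
        # masks 0.1.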


def test_topp_select_all():
    with torch.device(xm.xla_device()):
        logits = torch.tensor([[math.log(0.2), math.log(0.3), math.log(0.5)],
                               [math.log(0.5), math.log(0.1), math.log(0.4)]])
        result = apply_top_k_top_p_tpu(logits=logits.clone(),
                                       k=torch.tensor([3, 3]),
                                       p=torch.tensor([1.0, 1.0]))
        xm.mark_step()

        assert torch.allclose(logits.cpu(), result.cpu())


def test_topp_with_ties():
    with torch.device(xm.xla_device()):
        # Input has multiple math.log(0.3).
        logits = torch.tensor(
            [[math.log(0.3), math.log(0.3), math.log(0.3), math.log(0.1)]])
        result = apply_top_k_top_p_tpu(logits=logits.clone(),
                                       k=torch.tensor([4]),
                                       p=torch.tensor([0.2]))
        xm.mark_step()

        # All tied values are included in the top-p set. Tie-breaking is left
        # to the final sampling step (all tied tokens have an equal
        # probability of being chosen).
        expected_result = logits.clone().cpu()
        expected_result[0, 3] = float("-inf")
        assert torch.allclose(expected_result, result.cpu())
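        # Worked by hand: ascending cumulative probabilities are
        # (0.1, 0.4, 0.7, 1.0) and 1 - p = 0.8, so three elements fall at or
        # below the threshold and the cutoff probability is 0.3. Discarding
        # uses a strict "<" against the cutoff, so all three 0.3 ties survive
        # and only the 0.1 entry is masked.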


def test_both_topk_topp():
    with torch.device(xm.xla_device()):
        logits = torch.tensor([[math.log(0.2), math.log(0.3), math.log(0.5)],
                               [math.log(0.5), math.log(0.1), math.log(0.4)]])
        # Set k=1 for the first batch element.
        result = apply_top_k_top_p_tpu(logits=logits.clone(),
                                       k=torch.tensor([1, 3]),
                                       p=torch.tensor([0.79, 0.79]))
        xm.mark_step()

        # Since k=1 for the first batch element, expect only its largest
        # element to be selected.
        expected_result = logits.clone().cpu()
        expected_result[0, 0] = float("-inf")
        expected_result[0, 1] = float("-inf")
        expected_result[1, 1] = float("-inf")
        assert torch.allclose(expected_result, result.cpu())

vllm/v1/sample/ops/topk_topp_sampler.py

@@ -122,23 +122,48 @@ class TopKTopPSampler(nn.Module):
         k: Optional[torch.Tensor],
         p: Optional[torch.Tensor],
     ) -> torch.Tensor:
-        # If only top-k is specified, use pytorch's builtin topk op. This leads
-        # to significant speed up on TPU compared to using apply_top_k_top_p.
-        if k is not None and p is None:
-            topk_values, topk_indices = torch.topk(logits, k, dim=-1)
-            mask = torch.ones_like(logits, dtype=torch.bool)
-            mask.scatter_(-1, topk_indices, False)
-            logits.masked_fill_(mask, float('-inf'))
-        else:
-            # TODO Placeholder for TPU optimized topp kernel
-            # logits = apply_top_k_top_p(logits, k, p)
-            pass
+        logits = apply_top_k_top_p_tpu(logits, k, p)
         probs = logits.softmax(dim=-1, dtype=torch.float32)
         return random_sample(probs, generators)
+
+
+def apply_top_k_top_p_tpu(
+    logits: torch.Tensor,
+    k: torch.Tensor,
+    p: torch.Tensor,
+) -> torch.Tensor:
+    """
+    Apply top-k and top-p optimized for TPU.
+
+    This algorithm avoids using torch.scatter, which is extremely slow on
+    TPU. Instead, it finds a "cut-off" element in the original logits and,
+    after thresholding the logits with this cut-off, the remaining elements
+    constitute the top-p set.
+
+    Note: in the case of a tie (i.e. multiple cut-off elements present in
+    the logits), all tied elements are included in the top-p set. In other
+    words, this function does not break ties. Instead, the tied tokens have
+    an equal chance of being chosen during final sampling, so the tie can be
+    considered broken then.
+    """
+    if k is not None:
+        logits = apply_top_k_only(logits, k)
+
+    if p is not None:
+        probs = logits.softmax(dim=-1)
+        probs_sort, _ = probs.sort(dim=-1, descending=False)
+        cumprob = torch.cumsum(probs_sort, dim=-1)
+        top_p_mask = cumprob <= 1 - p.unsqueeze(dim=1)
+        top_p_mask[:, -1] = False  # at least one
+        top_p_count = top_p_mask.sum(dim=-1).unsqueeze(1)
+        top_p_cutoff = probs_sort.gather(-1, top_p_count)
+        elements_to_discard = probs < top_p_cutoff
+        logits.masked_fill_(elements_to_discard, -float("inf"))
+    return logits
+
+
 def apply_top_k_top_p(
     logits: torch.Tensor,
     k: Optional[torch.Tensor],
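A minimal CPU sketch (not part of the diff) can help when tracing the cutoff logic above, replayed here on a 3-token vocabulary. With p = 0.79 and probabilities (0.2, 0.3, 0.5), the ascending cumulative sums are (0.2, 0.5, 1.0); only the 0.2 entry falls at or below 1 - p = 0.21, so the cutoff probability is 0.3 and only that one token is discarded:

import torch

# Toy replay of the cut-off steps; values mirror test_topp_basic's first row.
probs = torch.tensor([[0.2, 0.3, 0.5]])
p = torch.tensor([0.79])
probs_sort, _ = probs.sort(dim=-1, descending=False)  # [[0.2, 0.3, 0.5]]
cumprob = torch.cumsum(probs_sort, dim=-1)            # [[0.2, 0.5, 1.0]]
top_p_mask = cumprob <= 1 - p.unsqueeze(dim=1)        # [[True, False, False]]
top_p_mask[:, -1] = False                             # always keep the top token
top_p_count = top_p_mask.sum(dim=-1).unsqueeze(1)     # [[1]]: one element below cutoff
top_p_cutoff = probs_sort.gather(-1, top_p_count)     # [[0.3]]: smallest kept prob
print(probs < top_p_cutoff)                           # [[True, False, False]]; True = masked to -inf

Note that no scatter is needed anywhere: the mask comes from a plain elementwise comparison against the per-row cutoff, which is the point of the commit title.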
@@ -199,7 +224,7 @@ def apply_top_k_only(
     max_top_k = k.max()
     # topk.values tensor has shape [batch_size, max_top_k].
     # Convert top k to 0-based index in range [0, max_top_k).
-    k_index = k.sub_(1).unsqueeze(1)
+    k_index = k.sub_(1).unsqueeze(1).expand(logits.shape[0], 1)
     top_k_mask = logits.topk(max_top_k, dim=1).values.gather(1, k_index.long())
    # Handle non-topk rows.
     top_k_mask.masked_fill_(no_top_k_mask.unsqueeze(1), -float("inf"))
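The added .expand is what lets a single shared k (such as the shape-[1] no_op_k tensor in the new test) serve a whole batch: torch.gather returns a tensor shaped like its index, so without the expand a [1, 1] index would yield a single cutoff row instead of one per request. A small standalone repro (not part of the diff; all names are local to this sketch):

import torch

logits = torch.tensor([[3.0, 1.0, 2.0],
                       [0.5, 2.5, 1.5]])
k = torch.tensor([2])                          # one k shared by both rows
k_index = k.sub(1).unsqueeze(1)                # shape [1, 1]
topk_vals = logits.topk(2, dim=1).values       # [[3.0, 2.0], [2.5, 1.5]]
print(topk_vals.gather(1, k_index).shape)      # torch.Size([1, 1]) -- one row only
expanded = k_index.expand(logits.shape[0], 1)  # shape [2, 1], no data copied
print(topk_vals.gather(1, expanded).shape)     # torch.Size([2, 1]) -- per-row cutoff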