[V1][TPU] Speed up top-k on TPU by using torch.topk (#15242)

Signed-off-by: Hyesoo Yang <hyeygit@gmail.com>
Author:    Hyesoo Yang
Date:      2025-03-20 19:19:40 -07:00
Committed: GitHub
Commit:    47195057e9 (parent 6edbfa924d)

3 changed files with 29 additions and 4 deletions


@@ -39,7 +39,7 @@ def test_sampler_compilation(model_name: str, monkeypatch):
     sampling_params = SamplingParams(
         temperature=0.7,
         # top_p=0.6, # TODO too slow!
-        # top_k=10,
+        top_k=10,
         min_p=0.2,
         max_tokens=16)
     s = time()
@@ -49,6 +49,7 @@ def test_sampler_compilation(model_name: str, monkeypatch):
     # Second request with different params, but for which we
     # compiled for in previous eager iteration.
     sampling_params = SamplingParams(temperature=0.1,
+                                     top_k=12,
                                      min_p=0.8,
                                      max_tokens=24)
     s = time()
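
Not part of the diff: a minimal offline-inference sketch of how a request would exercise the now-enabled top-k path, mirroring the parameters used in the test above. The model name and prompt are placeholders, and it assumes a working vLLM install on a supported backend.

    from vllm import LLM, SamplingParams

    # Placeholder model; any model supported by the target backend works.
    llm = LLM(model="Qwen/Qwen2.5-1.5B-Instruct")
    sampling_params = SamplingParams(temperature=0.7,
                                     top_k=10,
                                     min_p=0.2,
                                     max_tokens=16)
    outputs = llm.generate(["The capital of France is"], sampling_params)
    print(outputs[0].outputs[0].text)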


@@ -95,6 +95,7 @@ if TYPE_CHECKING:
     VLLM_DP_MASTER_PORT: int = 0
     VLLM_MARLIN_USE_ATOMIC_ADD: bool = False
     VLLM_V0_USE_OUTLINES_CACHE: bool = False
+    VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION: bool = False
 def get_default_cache_root():
@@ -623,6 +624,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
     # an environment with potentially malicious users.
     "VLLM_V0_USE_OUTLINES_CACHE":
     lambda: os.environ.get("VLLM_V0_USE_OUTLINES_CACHE", "0") == "1",
+    # If set, disables TPU-specific optimization for top-k & top-p sampling
+    "VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION":
+    lambda: bool(int(os.environ["VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION"]))
+    if "VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION" in os.environ else None,
 }
 # end-env-vars-definition
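
Not part of the diff: a small sketch of how the new flag would be consumed, assuming a vLLM install and that vllm.envs resolves variables lazily through a module-level __getattr__, as the lambda table above implies.

    import os

    # Opt out of the TPU top-k fast path before vLLM reads the variable.
    os.environ["VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION"] = "1"

    import vllm.envs as envs

    # When the variable is unset, the lambda above returns None (falsy),
    # so the optimized TPU path stays enabled by default.
    print(envs.VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION)  # True here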


@@ -66,7 +66,14 @@ class TopKTopPSampler(nn.Module):
                     "best performance, please install FlashInfer.")
                 self.forward = self.forward_native
         elif current_platform.is_tpu():
-            self.forward = self.forward_tpu
+            if envs.VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION:
+                logger.warning(
+                    "TPU-specific optimization for top-k & top-p sampling are "
+                    "disabled, falling back to PyTorch-native implementation "
+                    "which could be very slow.")
+                self.forward = self.forward_native
+            else:
+                self.forward = self.forward_tpu
         else:
             self.forward = self.forward_native
@@ -105,8 +112,19 @@
         k: Optional[torch.Tensor],
         p: Optional[torch.Tensor],
     ) -> torch.Tensor:
-        # TODO Placeholder for TPU optimized topk/p kernel
-        # logits = apply_top_k_top_p(logits, k, p)
+        # If only top-k is specified, use pytorch's builtin topk op. This leads
+        # to significant speed up on TPU compared to using apply_top_k_top_p.
+        if k is not None and p is None:
+            topk_values, topk_indices = torch.topk(logits, k, dim=-1)
+            mask = torch.ones_like(logits, dtype=torch.bool)
+            mask.scatter_(-1, topk_indices, False)
+            logits.masked_fill_(mask, float('-inf'))
+        else:
+            # TODO Placeholder for TPU optimized topp kernel
+            # logits = apply_top_k_top_p(logits, k, p)
+            pass
         probs = logits.softmax(dim=-1, dtype=torch.float32)
         return random_sample(probs, generators)
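
For reference, a standalone sketch of the masking trick introduced in forward_tpu above. It assumes a [batch, vocab] float logits tensor and a single Python int k for clarity (the sampler itself receives a per-request k tensor); only torch is required.

    import torch

    def apply_top_k_mask(logits: torch.Tensor, k: int) -> torch.Tensor:
        # Keep the k largest logits per row; everything else becomes -inf,
        # so softmax assigns those positions exactly zero probability.
        _, topk_indices = torch.topk(logits, k, dim=-1)
        mask = torch.ones_like(logits, dtype=torch.bool)
        mask.scatter_(-1, topk_indices, False)  # unmask the top-k positions
        return logits.masked_fill(mask, float("-inf"))

    logits = torch.randn(2, 8)
    probs = apply_top_k_mask(logits, k=3).softmax(dim=-1)
    assert (probs > 0).sum(dim=-1).eq(3).all()  # exactly 3 candidates per row

The speedup claimed in the diff comment comes from leaning on torch.topk rather than the generic apply_top_k_top_p path, which sorts over the full vocabulary.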