[V1][TPU] Speed up top-k on TPU by using torch.topk (#15242)
Signed-off-by: Hyesoo Yang <hyeygit@gmail.com>
@@ -39,7 +39,7 @@ def test_sampler_compilation(model_name: str, monkeypatch):
     sampling_params = SamplingParams(
         temperature=0.7,
         # top_p=0.6, # TODO too slow!
-        # top_k=10,
+        top_k=10,
         min_p=0.2,
         max_tokens=16)
     s = time()
@@ -49,6 +49,7 @@ def test_sampler_compilation(model_name: str, monkeypatch):
     # Second request with different params, but which we already
     # compiled for in the previous eager iteration.
     sampling_params = SamplingParams(temperature=0.1,
+                                     top_k=12,
                                      min_p=0.8,
                                      max_tokens=24)
     s = time()
@@ -95,6 +95,7 @@ if TYPE_CHECKING:
     VLLM_DP_MASTER_PORT: int = 0
     VLLM_MARLIN_USE_ATOMIC_ADD: bool = False
     VLLM_V0_USE_OUTLINES_CACHE: bool = False
+    VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION: bool = False
 
 
 def get_default_cache_root():
@@ -623,6 +624,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
     # an environment with potentially malicious users.
     "VLLM_V0_USE_OUTLINES_CACHE":
     lambda: os.environ.get("VLLM_V0_USE_OUTLINES_CACHE", "0") == "1",
+
+    # If set, disables the TPU-specific optimization for top-k & top-p sampling.
+    "VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION":
+    lambda: bool(int(os.environ["VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION"]))
+    if "VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION" in os.environ else None,
 }
 
 # end-env-vars-definition
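The new entry follows the tri-state convention used elsewhere in this table: unset yields None, while "0"/"1" yield False/True. A minimal standalone sketch of that parsing rule (the helper name is ours, not vLLM API):

import os

def read_tristate_flag(name: str):
    # Unset -> None; otherwise "0"/"1" -> False/True, mirroring the lambda above.
    return bool(int(os.environ[name])) if name in os.environ else None

os.environ["VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION"] = "1"
assert read_tristate_flag("VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION") is True
del os.environ["VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION"]
assert read_tristate_flag("VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION") is None

Note that any nonzero integer string would also parse as True, and a non-integer value would raise; the entry assumes "0" or "1".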
@@ -66,7 +66,14 @@ class TopKTopPSampler(nn.Module):
                     "best performance, please install FlashInfer.")
                 self.forward = self.forward_native
         elif current_platform.is_tpu():
-            self.forward = self.forward_tpu
+            if envs.VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION:
+                logger.warning(
+                    "TPU-specific optimizations for top-k & top-p sampling "
+                    "are disabled; falling back to the PyTorch-native "
+                    "implementation, which can be very slow.")
+                self.forward = self.forward_native
+            else:
+                self.forward = self.forward_tpu
         else:
             self.forward = self.forward_native
 
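With this dispatch change, TPU runs take the optimized path by default, and the new flag opts back into forward_native (emitting the warning above). A hedged usage sketch, assuming the variable is set before vLLM reads it; the model name is purely illustrative:

import os

# Opt out of the TPU top-k fast path; leaving the flag unset keeps it enabled.
os.environ["VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION"] = "1"

from vllm import LLM, SamplingParams  # import after setting the flag

llm = LLM(model="Qwen/Qwen2.5-1.5B-Instruct")  # illustrative model
outputs = llm.generate(["Hello"], SamplingParams(temperature=0.7, top_k=10))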
@@ -105,8 +112,19 @@
         k: Optional[torch.Tensor],
         p: Optional[torch.Tensor],
     ) -> torch.Tensor:
-        # TODO Placeholder for TPU optimized topk/p kernel
-        # logits = apply_top_k_top_p(logits, k, p)
+        # If only top-k is specified, use PyTorch's builtin topk op. This
+        # leads to a significant speedup on TPU compared to apply_top_k_top_p.
+        if k is not None and p is None:
+            topk_values, topk_indices = torch.topk(logits, k, dim=-1)
+
+            mask = torch.ones_like(logits, dtype=torch.bool)
+            mask.scatter_(-1, topk_indices, False)
+            logits.masked_fill_(mask, float('-inf'))
+        else:
+            # TODO Placeholder for a TPU-optimized top-p kernel.
+            # logits = apply_top_k_top_p(logits, k, p)
+            pass
+
         probs = logits.softmax(dim=-1, dtype=torch.float32)
         return random_sample(probs, generators)
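The heart of the change: for the top-k-only case, a single torch.topk plus a boolean scatter mask replaces the sort-based apply_top_k_top_p path, which sorts the full vocabulary dimension. A self-contained sketch of the same masking pattern with a scalar k (shapes and names are our assumptions, not vLLM code):

import torch

def topk_mask(logits: torch.Tensor, k: int) -> torch.Tensor:
    """Keep the k largest logits per row; drive the rest to -inf."""
    _, topk_indices = torch.topk(logits, k, dim=-1)
    mask = torch.ones_like(logits, dtype=torch.bool)  # True = "drop this slot"
    mask.scatter_(-1, topk_indices, False)            # keep only the top-k slots
    return logits.masked_fill(mask, float('-inf'))

logits = torch.randn(2, 8)
filtered = topk_mask(logits, k=3)
# Exactly 3 finite entries per row; the rest get zero probability after softmax.
assert (filtered > float('-inf')).sum(dim=-1).eq(3).all()
probs = filtered.softmax(dim=-1, dtype=torch.float32)

Masking to -inf rather than gathering the top-k values keeps the logits tensor's shape static, which is friendly to XLA compilation on TPU.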