Mirror of https://github.com/vllm-project/vllm.git (synced 2025-10-20 14:53:52 +08:00)
[Bugfix][V1] Fix flashinfer sampling (#14815)
```diff
@@ -24,7 +24,24 @@ class TopKTopPSampler(nn.Module):
         super().__init__()
         if current_platform.is_cuda():
             if is_flashinfer_available:
-                if envs.VLLM_USE_FLASHINFER_SAMPLER is not False:
+                flashinfer_version = flashinfer.__version__
+                if flashinfer_version >= "0.2.3":
+                    # FIXME(DefTruth): Currently, we have errors when using
+                    # FlashInfer>=v0.2.3 for top-p & top-k sampling. As a
+                    # workaround, we disable FlashInfer for top-p & top-k
+                    # sampling by default while FlashInfer>=v0.2.3.
+                    # The sampling API removes the success return value
+                    # of all sampling API, which is not compatible with
+                    # earlier design.
+                    # https://github.com/flashinfer-ai/flashinfer/releases/
+                    # tag/v0.2.3
+                    logger.info(
+                        "Currently, FlashInfer top-p & top-k sampling sampler "
+                        "is disabled because FlashInfer>=v0.2.3 is not "
+                        "backward compatible. Falling back to the PyTorch-"
+                        "native implementation of top-p & top-k sampling.")
+                    self.forward = self.forward_native
+                elif envs.VLLM_USE_FLASHINFER_SAMPLER is not False:
                     # NOTE(woosuk): The V0 sampler doesn't use FlashInfer for
                     # sampling unless VLLM_USE_FLASHINFER_SAMPLER=1 (i.e., by
                     # default it is unused). For backward compatibility, we set
```
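For context, the sketch below illustrates what a PyTorch-native top-p & top-k sampling step looks like, i.e. the kind of path the commit falls back to when FlashInfer >= 0.2.3 is installed. It is a minimal, self-contained example assuming scalar `k` and `p` and a `[num_seqs, vocab_size]` logits tensor; the function name `native_top_k_top_p_sample` is hypothetical and this is not vLLM's actual `forward_native` implementation.

```python
# Illustrative sketch only; not vLLM's forward_native.
from typing import Optional

import torch


def native_top_k_top_p_sample(
    logits: torch.Tensor,  # [num_seqs, vocab_size]
    k: int,
    p: float,
    generator: Optional[torch.Generator] = None,
) -> torch.Tensor:
    """Apply top-k then top-p filtering and sample one token id per row."""
    # Top-k: mask everything below the k-th largest logit in each row.
    if k < logits.shape[-1]:
        kth = torch.topk(logits, k, dim=-1).values[..., -1, None]
        logits = logits.masked_fill(logits < kth, float("-inf"))

    # Top-p (nucleus): keep the smallest prefix of the sorted distribution
    # whose cumulative probability reaches p; drop the rest.
    sorted_logits, sorted_idx = torch.sort(logits, descending=True, dim=-1)
    probs = torch.softmax(sorted_logits, dim=-1)
    cum = probs.cumsum(dim=-1)
    # A token is dropped if the cumulative mass before it already exceeds p.
    drop = (cum - probs) > p
    sorted_logits = sorted_logits.masked_fill(drop, float("-inf"))
    # Scatter the filtered logits back to their original vocabulary positions.
    logits = torch.full_like(logits, float("-inf")).scatter(
        -1, sorted_idx, sorted_logits)

    # Sample one token id per sequence from the filtered distribution.
    probs = torch.softmax(logits, dim=-1)
    return torch.multinomial(
        probs, num_samples=1, generator=generator).squeeze(-1)
```

Because the first (highest-probability) token in each row is never masked by either filter, the final `multinomial` call always has a valid distribution to sample from.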