[Bugfix][V1] Fix flashinfer sampling (#14815)

Author: DefTruth
Date: 2025-03-15 11:42:38 +08:00
Committed by: GitHub
Parent: 9f37422779
Commit: acaea3bb07

@@ -24,7 +24,24 @@ class TopKTopPSampler(nn.Module):
         super().__init__()
         if current_platform.is_cuda():
             if is_flashinfer_available:
-                if envs.VLLM_USE_FLASHINFER_SAMPLER is not False:
+                flashinfer_version = flashinfer.__version__
+                if flashinfer_version >= "0.2.3":
+                    # FIXME(DefTruth): Currently, we see errors when using
+                    # FlashInfer>=v0.2.3 for top-p & top-k sampling. As a
+                    # workaround, we disable FlashInfer for top-p & top-k
+                    # sampling by default when FlashInfer>=v0.2.3.
+                    # v0.2.3 removed the success return value from all
+                    # sampling APIs, which is not compatible with the
+                    # earlier design.
+                    # https://github.com/flashinfer-ai/flashinfer/releases/
+                    # tag/v0.2.3
+                    logger.info(
+                        "Currently, the FlashInfer top-p & top-k sampler "
+                        "is disabled because FlashInfer>=v0.2.3 is not "
+                        "backward compatible. Falling back to the PyTorch-"
+                        "native implementation of top-p & top-k sampling.")
+                    self.forward = self.forward_native
+                elif envs.VLLM_USE_FLASHINFER_SAMPLER is not False:
                     # NOTE(woosuk): The V0 sampler doesn't use FlashInfer for
                     # sampling unless VLLM_USE_FLASHINFER_SAMPLER=1 (i.e., by
                     # default it is unused). For backward compatibility, we set
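
For context on the FIXME above: per the release notes linked in the comment, FlashInfer v0.2.3 dropped the success value its sampling APIs previously returned alongside the samples, which is why this commit routes to forward_native instead of calling a changed API. One further caveat: the guard compares version strings lexicographically, which works for the versions in play here but would misorder a future multi-digit component such as "0.10.0". Below is a minimal sketch, not part of this commit, of the same gate done with a parsed comparison; it assumes the packaging library is available, and the helper name flashinfer_needs_fallback is hypothetical.

# A sketch only, not code from this commit. Assumes the `packaging`
# library is installed; the helper name is hypothetical.
from packaging.version import Version

def flashinfer_needs_fallback(version_str: str) -> bool:
    """True when FlashInfer >= 0.2.3, i.e. when its sampling APIs no
    longer return a `success` value and the PyTorch-native top-p & top-k
    implementation should be used instead."""
    return Version(version_str) >= Version("0.2.3")

# Lexicographic string comparison misorders multi-digit components:
assert not ("0.10.0" >= "0.2.3")              # string compare says "older"
assert flashinfer_needs_fallback("0.10.0")    # parsed compare is correct
assert flashinfer_needs_fallback("0.2.3")
assert not flashinfer_needs_fallback("0.2.2")

A parsed comparison like this also keeps the gate correct for pre-release tags (e.g. "0.2.3rc1"), which plain string comparison handles inconsistently.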