[Kernel] Lazy import FlashInfer (#26977)

Author: Jee Jee Li
Date: 2025-10-17 12:48:18 +08:00
Committed by: GitHub
Parent: 87bc0c492f
Commit: fec2b341ad
2 changed files with 25 additions and 38 deletions
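In outline, the change removes the module-level try/except probe for FlashInfer and imports it inside the code paths that actually use it, so importing vLLM's sampler module no longer imports flashinfer at all. A minimal sketch of the before/after pattern, condensed from the diffs below (the helper name _require_flashinfer is illustrative, not part of vLLM):

    # Before: optional dependency probed at module import time, even for
    # users who never enable the FlashInfer sampler.
    #
    #     try:
    #         import flashinfer.sampling
    #         is_flashinfer_available = True
    #     except ImportError:
    #         is_flashinfer_available = False

    # After: the import and the version gate run lazily, on first use.
    from packaging import version

    def _require_flashinfer():
        import flashinfer  # raises ImportError if FlashInfer is not installed

        if version.parse(flashinfer.__version__) < version.parse("0.2.3"):
            raise ImportError("FlashInfer >= 0.2.3 required for top-k/top-p sampling.")
        return flashinfer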

File 1 of 2 — the top-k/top-p sampler test (path elided in the source; it imports from vllm.v1.sample.ops.topk_topp_sampler):

@@ -5,20 +5,13 @@ import torch
 from torch import Generator
 
 from vllm.platforms import current_platform
-from vllm.v1.sample.ops.topk_topp_sampler import (
-    apply_top_k_top_p,
-    is_flashinfer_available,
-)
+from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
 
 DEVICE = current_platform.device_type
 
 BATCH_SIZE = 1024
 VOCAB_SIZE = 128 * 1024
 
-FLASHINFER_ENABLED = current_platform.is_cuda() and is_flashinfer_available
-if is_flashinfer_available:
-    from flashinfer.sampling import top_k_renorm_probs, top_p_renorm_probs
-
 
 @pytest.fixture(autouse=True)
 def reset_default_device():
@@ -65,6 +58,14 @@ def test_flashinfer_sampler():
     sampling results due to randomness), so we will compare the probability
     renormed consequently by top-k and then top-p of FlashInfer implementation.
     """
+    try:
+        from flashinfer.sampling import top_k_renorm_probs, top_p_renorm_probs
+
+        is_flashinfer_available = True
+    except ImportError:
+        is_flashinfer_available = False
+
+    FLASHINFER_ENABLED = current_platform.is_cuda() and is_flashinfer_available
     if not FLASHINFER_ENABLED:
         pytest.skip("FlashInfer not installed or not available on this platform.")
 

File 2 of 2 — vllm/v1/sample/ops/topk_topp_sampler.py:

@@ -13,13 +13,6 @@ from vllm.platforms import CpuArchEnum, current_platform
 
 logger = init_logger(__name__)
 
-try:
-    import flashinfer.sampling
-
-    is_flashinfer_available = True
-except ImportError:
-    is_flashinfer_available = False
-
 
 class TopKTopPSampler(nn.Module):
     """
@@ -38,32 +31,18 @@ class TopKTopPSampler(nn.Module):
             logprobs_mode not in ("processed_logits", "processed_logprobs")
             and current_platform.is_cuda()
         ):
-            if is_flashinfer_available:
-                flashinfer_version = flashinfer.__version__
-                if version.parse(flashinfer_version) < version.parse("0.2.3"):
-                    logger.warning_once(
-                        "FlashInfer version >= 0.2.3 required. "
-                        "Falling back to default sampling implementation."
-                    )
-                    self.forward = self.forward_native
-                elif envs.VLLM_USE_FLASHINFER_SAMPLER:
-                    # Users must opt in explicitly via VLLM_USE_FLASHINFER_SAMPLER=1.
-                    logger.info_once("Using FlashInfer for top-p & top-k sampling.")
-                    self.forward = self.forward_cuda
-                else:
-                    logger.debug_once(
-                        "FlashInfer top-p/top-k sampling is available but disabled "
-                        "by default. Set VLLM_USE_FLASHINFER_SAMPLER=1 to opt in "
-                        "after verifying accuracy for your workloads."
-                    )
-                    self.forward = self.forward_native
-            else:
-                logger.warning_once(
-                    "FlashInfer is not available. Falling back to the PyTorch-"
-                    "native implementation of top-p & top-k sampling. For the "
-                    "best performance, please install FlashInfer."
-                )
-                self.forward = self.forward_native
+            if envs.VLLM_USE_FLASHINFER_SAMPLER:
+                # Users must opt in explicitly via VLLM_USE_FLASHINFER_SAMPLER=1.
+                logger.info_once("Using FlashInfer for top-p & top-k sampling.")
+                self.forward = self.forward_cuda
+            else:
+                logger.debug_once(
+                    "FlashInfer top-p/top-k sampling is available but disabled "
+                    "by default. Set VLLM_USE_FLASHINFER_SAMPLER=1 to opt in "
+                    "after verifying accuracy for your workloads."
+                )
+                self.forward = self.forward_native
         elif current_platform.is_cpu():
             arch = current_platform.get_cpu_architecture()
             # Fall back to native implementation for POWERPC and RISCV.
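Note that the constructor no longer inspects FlashInfer availability or version at all: on CUDA, the choice between forward_cuda and forward_native is driven purely by the VLLM_USE_FLASHINFER_SAMPLER opt-in, and the version gate moves into flashinfer_sample itself, shown in the next hunk.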
@@ -278,6 +257,13 @@ def flashinfer_sample(
     does not. Call this function at the end of the forward pass to minimize
     the synchronization overhead.
     """
+    import flashinfer
+
+    if version.parse(flashinfer.__version__) < version.parse("0.2.3"):
+        raise ImportError(
+            "FlashInfer version >= 0.2.3 required for top-k and top-p sampling. "
+        )
+
     assert not (k is None and p is None)
     if k is None:
         # Top-p only.
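With the lazy import, a missing or pre-0.2.3 FlashInfer now surfaces as an ImportError on the first opted-in sampling call rather than as a warning at startup. A quick, standalone way to check what a given environment will do (flashinfer.__version__ is the attribute the new gate reads; the snippet itself is illustrative):

    try:
        import flashinfer
        print("flashinfer", flashinfer.__version__)  # the new gate requires >= 0.2.3
    except ImportError:
        print("flashinfer not installed; VLLM_USE_FLASHINFER_SAMPLER=1 would fail at sample time")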