mirror of https://github.com/vllm-project/vllm.git
synced 2025-10-20 23:03:52 +08:00
[Kernel] Lazy import FlashInfer (#26977)
@@ -5,20 +5,13 @@ import torch
 
 from torch import Generator
 
 from vllm.platforms import current_platform
-from vllm.v1.sample.ops.topk_topp_sampler import (
-    apply_top_k_top_p,
-    is_flashinfer_available,
-)
+from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
 
 DEVICE = current_platform.device_type
 
 BATCH_SIZE = 1024
 VOCAB_SIZE = 128 * 1024
 
-FLASHINFER_ENABLED = current_platform.is_cuda() and is_flashinfer_available
-if is_flashinfer_available:
-    from flashinfer.sampling import top_k_renorm_probs, top_p_renorm_probs
-
 @pytest.fixture(autouse=True)
 def reset_default_device():
@@ -65,6 +58,14 @@ def test_flashinfer_sampler():
     sampling results due to randomness), so we will compare the probability
     renormed consequently by top-k and then top-p of FlashInfer implementation.
     """
+    try:
+        from flashinfer.sampling import top_k_renorm_probs, top_p_renorm_probs
+
+        is_flashinfer_available = True
+    except ImportError:
+        is_flashinfer_available = False
+
+    FLASHINFER_ENABLED = current_platform.is_cuda() and is_flashinfer_available
     if not FLASHINFER_ENABLED:
         pytest.skip("FlashInfer not installed or not available on this platform.")
 
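The test now probes for FlashInfer inside the test body instead of at module level, so merely collecting the test file never imports the library. A minimal sketch of the same guard-and-skip pattern; the test name and the importlib-based probe are illustrative, not part of this diff:

    import importlib

    import pytest


    def test_optional_backend():
        # Probe only when the test actually runs; collecting this file
        # never triggers the heavy import.
        try:
            sampling = importlib.import_module("flashinfer.sampling")
        except ImportError:
            pytest.skip("FlashInfer not installed.")
        # The renorm helpers the real test relies on live in this module.
        assert hasattr(sampling, "top_k_renorm_probs")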
@@ -13,13 +13,6 @@ from vllm.platforms import CpuArchEnum, current_platform
 
 logger = init_logger(__name__)
 
-try:
-    import flashinfer.sampling
-
-    is_flashinfer_available = True
-except ImportError:
-    is_flashinfer_available = False
-
 
 class TopKTopPSampler(nn.Module):
     """
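With the module-level try/except gone, importing the sampler module no longer pays FlashInfer's import cost up front. If a caller still wants a cheap availability check without triggering the import, importlib.util.find_spec can answer it from the import machinery alone; a sketch under that assumption (the helper name is illustrative):

    import importlib.util


    def flashinfer_installed() -> bool:
        # find_spec consults finders and package metadata only; it does not
        # execute flashinfer/__init__.py, so the probe stays cheap.
        return importlib.util.find_spec("flashinfer") is not None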
@@ -38,32 +31,18 @@ class TopKTopPSampler(nn.Module):
             logprobs_mode not in ("processed_logits", "processed_logprobs")
             and current_platform.is_cuda()
         ):
-            if is_flashinfer_available:
-                flashinfer_version = flashinfer.__version__
-                if version.parse(flashinfer_version) < version.parse("0.2.3"):
-                    logger.warning_once(
-                        "FlashInfer version >= 0.2.3 required. "
-                        "Falling back to default sampling implementation."
-                    )
-                    self.forward = self.forward_native
-                elif envs.VLLM_USE_FLASHINFER_SAMPLER:
-                    # Users must opt in explicitly via VLLM_USE_FLASHINFER_SAMPLER=1.
-                    logger.info_once("Using FlashInfer for top-p & top-k sampling.")
-                    self.forward = self.forward_cuda
-                else:
-                    logger.debug_once(
-                        "FlashInfer top-p/top-k sampling is available but disabled "
-                        "by default. Set VLLM_USE_FLASHINFER_SAMPLER=1 to opt in "
-                        "after verifying accuracy for your workloads."
-                    )
-                    self.forward = self.forward_native
+            if envs.VLLM_USE_FLASHINFER_SAMPLER:
+                # Users must opt in explicitly via VLLM_USE_FLASHINFER_SAMPLER=1.
+                logger.info_once("Using FlashInfer for top-p & top-k sampling.")
+                self.forward = self.forward_cuda
             else:
-                logger.warning_once(
-                    "FlashInfer is not available. Falling back to the PyTorch-"
-                    "native implementation of top-p & top-k sampling. For the "
-                    "best performance, please install FlashInfer."
-                )
+                logger.debug_once(
+                    "FlashInfer top-p/top-k sampling is available but disabled "
+                    "by default. Set VLLM_USE_FLASHINFER_SAMPLER=1 to opt in "
+                    "after verifying accuracy for your workloads."
+                )
                 self.forward = self.forward_native
 
         elif current_platform.is_cpu():
             arch = current_platform.get_cpu_architecture()
             # Fall back to native implementation for POWERPC and RISCV.
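The selection logic now keys off the opt-in flag alone and leaves the import to the forward path. Stripped of the vLLM plumbing (envs, logger, and the real forward implementations are all stubbed here as assumptions), the dispatch reduces to binding the method once at construction time:

    import os


    class Sampler:
        def __init__(self) -> None:
            # Mirrors VLLM_USE_FLASHINFER_SAMPLER: the CUDA fast path is
            # chosen only when the user explicitly opts in with "1".
            if os.environ.get("VLLM_USE_FLASHINFER_SAMPLER") == "1":
                self.forward = self.forward_cuda
            else:
                self.forward = self.forward_native

        def forward_cuda(self, logits):
            raise NotImplementedError  # stand-in for the FlashInfer path

        def forward_native(self, logits):
            raise NotImplementedError  # stand-in for the PyTorch path

Binding self.forward in __init__ keeps the per-call path free of any environment or availability checks.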
@@ -278,6 +257,13 @@ def flashinfer_sample(
     does not. Call this function at the end of the forward pass to minimize
     the synchronization overhead.
     """
+    import flashinfer
+
+    if version.parse(flashinfer.__version__) < version.parse("0.2.3"):
+        raise ImportError(
+            "FlashInfer version >= 0.2.3 required for top-k and top-p sampling. "
+        )
+
     assert not (k is None and p is None)
     if k is None:
         # Top-p only.
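flashinfer_sample now imports FlashInfer and enforces the version floor on first call, which is the heart of the lazy-import change: processes that never sample through FlashInfer never import it. The pattern in isolation, reusing packaging.version as the diff does; the wrapper name and the renorm call are illustrative:

    import torch
    from packaging import version


    def renorm_top_k(probs: torch.Tensor, k: int) -> torch.Tensor:
        # Deferred import: the dependency (and its import latency) is only
        # paid on the first call, never at module import time.
        import flashinfer
        from flashinfer.sampling import top_k_renorm_probs

        if version.parse(flashinfer.__version__) < version.parse("0.2.3"):
            raise ImportError("FlashInfer >= 0.2.3 required for top-k/top-p sampling.")
        return top_k_renorm_probs(probs, k)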