mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 14:53:52 +08:00
[V1] Use FlashInfer by default on Blackwell GPUs (#19118)
This commit is contained in:
@ -229,6 +229,21 @@ class CudaPlatformBase(Platform):
|
||||
logger.info_once("Using Triton backend on V1 engine.")
|
||||
return ("vllm.v1.attention.backends."
|
||||
"triton_attn.TritonAttentionBackend")
|
||||
if cls.is_device_capability(100):
|
||||
# Prefer FlashInfer for V1 on Blackwell GPUs if installed
|
||||
try:
|
||||
import flashinfer # noqa: F401
|
||||
logger.info_once(
|
||||
"Using FlashInfer backend on V1 engine by default for "
|
||||
"Blackwell (SM 10.0) GPUs.")
|
||||
return ("vllm.v1.attention.backends."
|
||||
"flashinfer.FlashInferBackend")
|
||||
except ImportError:
|
||||
logger.info_once(
|
||||
"FlashInfer failed to import for V1 engine on "
|
||||
"Blackwell (SM 10.0) GPUs; it is recommended to "
|
||||
"install FlashInfer for better performance.")
|
||||
pass
|
||||
if cls.has_device_capability(80):
|
||||
logger.info_once("Using Flash Attention backend on V1 engine.")
|
||||
return ("vllm.v1.attention.backends."
|
||||
|
@ -228,6 +228,30 @@ class Platform:
|
||||
|
||||
return current_capability.to_int() >= capability
|
||||
|
||||
@classmethod
|
||||
def is_device_capability(
|
||||
cls,
|
||||
capability: Union[tuple[int, int], int],
|
||||
device_id: int = 0,
|
||||
) -> bool:
|
||||
"""
|
||||
Test whether this platform has exactly the specified device capability.
|
||||
|
||||
The `capability` argument can either be:
|
||||
|
||||
- A tuple `(major, minor)`.
|
||||
- An integer `<major><minor>`. (See
|
||||
[`DeviceCapability.to_int`][vllm.platforms.interface.DeviceCapability.to_int])
|
||||
"""
|
||||
current_capability = cls.get_device_capability(device_id=device_id)
|
||||
if current_capability is None:
|
||||
return False
|
||||
|
||||
if isinstance(capability, tuple):
|
||||
return current_capability == capability
|
||||
|
||||
return current_capability.to_int() == capability
|
||||
|
||||
@classmethod
|
||||
def get_device_name(cls, device_id: int = 0) -> str:
|
||||
"""Get the name of a device."""
|
||||
|
Reference in New Issue
Block a user