[V1] Use FlashInfer by default on Blackwell GPUs (#19118)

Michael Goin
2025-06-05 15:40:39 -04:00
committed by GitHub
parent aa49f14832
commit 87360308b7
2 changed files with 39 additions and 0 deletions

vllm/platforms/cuda.py

@@ -229,6 +229,21 @@ class CudaPlatformBase(Platform):
                 logger.info_once("Using Triton backend on V1 engine.")
                 return ("vllm.v1.attention.backends."
                         "triton_attn.TritonAttentionBackend")
+            if cls.is_device_capability(100):
+                # Prefer FlashInfer for V1 on Blackwell GPUs if installed
+                try:
+                    import flashinfer  # noqa: F401
+                    logger.info_once(
+                        "Using FlashInfer backend on V1 engine by default for "
+                        "Blackwell (SM 10.0) GPUs.")
+                    return ("vllm.v1.attention.backends."
+                            "flashinfer.FlashInferBackend")
+                except ImportError:
+                    logger.info_once(
+                        "FlashInfer failed to import for V1 engine on "
+                        "Blackwell (SM 10.0) GPUs; it is recommended to "
+                        "install FlashInfer for better performance.")
+                    pass
             if cls.has_device_capability(80):
                 logger.info_once("Using Flash Attention backend on V1 engine.")
                 return ("vllm.v1.attention.backends."

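The added branch sits ahead of the existing Flash Attention default, so the net selection order on Blackwell is: FlashInfer if importable, otherwise fall through to the pre-existing path. A minimal standalone sketch of that precedence (the function name and the Flash Attention module path are illustrative assumptions, not vLLM's API; the hunk above truncates the actual fallback path):

    import importlib.util

    def pick_v1_attention_backend(capability: int) -> str:
        # Sketch of the precedence this commit introduces: on SM 10.0
        # (Blackwell) exactly, prefer FlashInfer when it is importable;
        # otherwise fall through to the pre-existing Ampere+ default.
        if capability == 100:  # exact match, like is_device_capability(100)
            if importlib.util.find_spec("flashinfer") is not None:
                return "vllm.v1.attention.backends.flashinfer.FlashInferBackend"
            # FlashInfer not installed: log a hint and fall through.
        if capability >= 80:  # at-least check, like has_device_capability(80)
            # Module path assumed for illustration; the hunk truncates it.
            return "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend"
        raise NotImplementedError(f"No V1 attention backend for SM {capability}")
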
vllm/platforms/interface.py

@@ -228,6 +228,30 @@ class Platform:
 
         return current_capability.to_int() >= capability
 
+    @classmethod
+    def is_device_capability(
+        cls,
+        capability: Union[tuple[int, int], int],
+        device_id: int = 0,
+    ) -> bool:
+        """
+        Test whether this platform has exactly the specified device capability.
+
+        The `capability` argument can either be:
+
+        - A tuple `(major, minor)`.
+        - An integer `<major><minor>`. (See
+          [`DeviceCapability.to_int`][vllm.platforms.interface.DeviceCapability.to_int])
+        """
+        current_capability = cls.get_device_capability(device_id=device_id)
+        if current_capability is None:
+            return False
+
+        if isinstance(capability, tuple):
+            return current_capability == capability
+
+        return current_capability.to_int() == capability
+
     @classmethod
     def get_device_name(cls, device_id: int = 0) -> str:
         """Get the name of a device."""