mirror of https://github.com/vllm-project/vllm.git
synced 2025-10-20 23:03:52 +08:00
Fall back if flashinfer comm module not found (#20936)
Signed-off-by: Yong Hoon Shin <yhshin@meta.com>
@@ -20,10 +20,12 @@ from vllm.utils import direct_register_custom_op
 from .vllm_inductor_pass import VllmInductorPass
 
 if find_spec("flashinfer"):
-    import flashinfer.comm as flashinfer_comm
-
-    flashinfer_comm = (flashinfer_comm if hasattr(
-        flashinfer_comm, "trtllm_allreduce_fusion") else None)
+    try:
+        import flashinfer.comm as flashinfer_comm
+        flashinfer_comm = (flashinfer_comm if hasattr(
+            flashinfer_comm, "trtllm_allreduce_fusion") else None)
+    except ImportError:
+        flashinfer_comm = None
 else:
     flashinfer_comm = None
 
 from vllm.platforms import current_platform
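
Taken together, the guarded import now reads as below. This is a self-contained sketch of the patched block, with the importlib import added so it runs on its own (the real file already has find_spec in scope):

from importlib.util import find_spec

# find_spec() only checks that the "flashinfer" package is discoverable;
# importing the comm submodule can still raise ImportError (e.g. a partial
# install or missing optional extension), so the import is now guarded too.
if find_spec("flashinfer"):
    try:
        import flashinfer.comm as flashinfer_comm
        # Keep the module only if the fused-allreduce entry point exists.
        flashinfer_comm = (flashinfer_comm if hasattr(
            flashinfer_comm, "trtllm_allreduce_fusion") else None)
    except ImportError:
        flashinfer_comm = None
else:
    flashinfer_comm = None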
@@ -411,7 +413,8 @@ class AllReduceFusionPass(VllmInductorPass):
         use_fp32_lamport = self.model_dtype == torch.float32
         if flashinfer_comm is None:
             logger.warning(
-                "Flashinfer is not installed, skipping allreduce fusion pass")
+                "Flashinfer is not installed or comm module not found, "
+                "skipping allreduce fusion pass")
             return
         # Check if the world size is supported
         if self.tp_size not in _FI_MAX_SIZES:
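
For context, a minimal sketch of how the pass consumes the sentinel; the class below is a hypothetical reduction of vllm's AllReduceFusionPass, keeping only the early-return path that the reworded warning belongs to:

import logging

logger = logging.getLogger(__name__)

# Stand-in for the module-level result of the guarded import above.
flashinfer_comm = None

class AllReduceFusionPassSketch:
    # Hypothetical reduction of the pass's __init__: whether flashinfer is
    # missing entirely or just its comm module, flashinfer_comm is None and
    # the pass warns once and disables itself instead of raising.
    def __init__(self, tp_size: int):
        self.tp_size = tp_size
        if flashinfer_comm is None:
            logger.warning(
                "Flashinfer is not installed or comm module not found, "
                "skipping allreduce fusion pass")
            return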