Fall back if flashinfer comm module not found (#20936)

Signed-off-by: Yong Hoon Shin <yhshin@meta.com>
This commit is contained in:
Yong Hoon Shin
2025-07-14 16:11:18 -07:00
committed by GitHub
parent 55e1c66da5
commit 61e20828da

View File

@@ -20,10 +20,12 @@ from vllm.utils import direct_register_custom_op
 from .vllm_inductor_pass import VllmInductorPass
 if find_spec("flashinfer"):
-    import flashinfer.comm as flashinfer_comm
-    flashinfer_comm = (flashinfer_comm if hasattr(
-        flashinfer_comm, "trtllm_allreduce_fusion") else None)
+    try:
+        import flashinfer.comm as flashinfer_comm
+        flashinfer_comm = (flashinfer_comm if hasattr(
+            flashinfer_comm, "trtllm_allreduce_fusion") else None)
+    except ImportError:
+        flashinfer_comm = None
 else:
     flashinfer_comm = None
 from vllm.platforms import current_platform
@@ -411,7 +413,8 @@ class AllReduceFusionPass(VllmInductorPass):
         use_fp32_lamport = self.model_dtype == torch.float32
         if flashinfer_comm is None:
             logger.warning(
-                "Flashinfer is not installed, skipping allreduce fusion pass")
+                "Flashinfer is not installed or comm module not found, "
+                "skipping allreduce fusion pass")
             return
         # Check if the world size is supported
         if self.tp_size not in _FI_MAX_SIZES: