diff --git a/vllm/compilation/collective_fusion.py b/vllm/compilation/collective_fusion.py
index 5892669a3a..97cb2995cb 100644
--- a/vllm/compilation/collective_fusion.py
+++ b/vllm/compilation/collective_fusion.py
@@ -20,10 +20,12 @@ from vllm.utils import direct_register_custom_op
 from .vllm_inductor_pass import VllmInductorPass
 
 if find_spec("flashinfer"):
-    import flashinfer.comm as flashinfer_comm
-
-    flashinfer_comm = (flashinfer_comm if hasattr(
-        flashinfer_comm, "trtllm_allreduce_fusion") else None)
+    try:
+        import flashinfer.comm as flashinfer_comm
+        flashinfer_comm = (flashinfer_comm if hasattr(
+            flashinfer_comm, "trtllm_allreduce_fusion") else None)
+    except ImportError:
+        flashinfer_comm = None
 else:
     flashinfer_comm = None
 from vllm.platforms import current_platform
@@ -411,7 +413,8 @@ class AllReduceFusionPass(VllmInductorPass):
         use_fp32_lamport = self.model_dtype == torch.float32
         if flashinfer_comm is None:
             logger.warning(
                 "Flashinfer is not installed or comm module not found, "
                 "skipping allreduce fusion pass")
             return
         # Check if the world size is supported
         if self.tp_size not in _FI_MAX_SIZES:
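
For reference, the guard this patch arrives at is shown below as a standalone sketch (with the `find_spec` import made explicit, since it comes from elsewhere in the file). The point of the change is that `importlib.util.find_spec` only confirms the package is discoverable; the import itself can still raise `ImportError` (e.g. a broken or partially installed `flashinfer`), so both the `try/except` and the `hasattr` feature check are needed before the module is treated as usable.

```python
from importlib.util import find_spec

if find_spec("flashinfer"):
    try:
        import flashinfer.comm as flashinfer_comm

        # Keep the module only if it exposes the fusion entry point
        # this pass relies on; otherwise treat it as unavailable.
        flashinfer_comm = (flashinfer_comm if hasattr(
            flashinfer_comm, "trtllm_allreduce_fusion") else None)
    except ImportError:
        # flashinfer is discoverable but its comm submodule failed to
        # import; fall back to the same "not available" state.
        flashinfer_comm = None
else:
    flashinfer_comm = None
```

Downstream code (such as `AllReduceFusionPass` in the second hunk) then only has to test `flashinfer_comm is None` to decide whether to skip the fusion pass, which is why the warning message is widened to cover both failure modes.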