diff --git a/docker/Dockerfile b/docker/Dockerfile
index f9e07acb85..8f482b393c 100644
--- a/docker/Dockerfile
+++ b/docker/Dockerfile
@@ -359,8 +359,8 @@ RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/dist
 
 # Install FlashInfer pre-compiled kernel cache and binaries
 # https://docs.flashinfer.ai/installation.html
 RUN --mount=type=cache,target=/root/.cache/uv \
-    uv pip install --system flashinfer-cubin==0.4.0 \
-    && uv pip install --system flashinfer-jit-cache==0.4.0 \
+    uv pip install --system flashinfer-cubin==0.4.1 \
+    && uv pip install --system flashinfer-jit-cache==0.4.1 \
         --extra-index-url https://flashinfer.ai/whl/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.') \
     && flashinfer show-config
diff --git a/docker/Dockerfile.nightly_torch b/docker/Dockerfile.nightly_torch
index 165256a9bd..6dfa560178 100644
--- a/docker/Dockerfile.nightly_torch
+++ b/docker/Dockerfile.nightly_torch
@@ -246,7 +246,7 @@ RUN pip install setuptools==75.6.0 packaging==23.2 ninja==1.11.1.3 build==1.2.2.
 
 # build flashinfer for torch nightly from source around 10 mins
-# release version: v0.4.0
+# release version: v0.4.1
 # todo(elainewy): cache flashinfer build result for faster build
 ENV CCACHE_DIR=/root/.cache/ccache
 RUN --mount=type=cache,target=/root/.cache/ccache \
@@ -254,7 +254,7 @@ RUN --mount=type=cache,target=/root/.cache/ccache \
     echo "git clone flashinfer..." \
     && git clone --recursive https://github.com/flashinfer-ai/flashinfer.git \
     && cd flashinfer \
-    && git checkout v0.4.0 \
+    && git checkout v0.4.1 \
     && git submodule update --init --recursive \
     && echo "finish git clone flashinfer..." \
     && rm -rf build \
diff --git a/requirements/cuda.txt b/requirements/cuda.txt
index 06956415d0..411c8de537 100644
--- a/requirements/cuda.txt
+++ b/requirements/cuda.txt
@@ -12,4 +12,4 @@ torchvision==0.23.0 # Required for phi3v processor. See https://github.com/pytor
 # https://github.com/facebookresearch/xformers/releases/tag/v0.0.32.post1
 xformers==0.0.32.post1; platform_system == 'Linux' and platform_machine == 'x86_64' # Requires PyTorch >= 2.8
 # FlashInfer should be updated together with the Dockerfile
-flashinfer-python==0.4.0
\ No newline at end of file
+flashinfer-python==0.4.1
\ No newline at end of file
diff --git a/tests/kernels/moe/test_ocp_mx_moe.py b/tests/kernels/moe/test_ocp_mx_moe.py
index 7a5d10a87b..91b508d416 100644
--- a/tests/kernels/moe/test_ocp_mx_moe.py
+++ b/tests/kernels/moe/test_ocp_mx_moe.py
@@ -37,7 +37,7 @@ if TRTLLM_GEN_MXFP4_AVAILABLE:
         trtllm_fp4_block_scale_moe,
     )
     from flashinfer.fp4_quantization import nvfp4_block_scale_interleave
-    from flashinfer.fused_moe.core import _maybe_get_cached_w2_permute_indices
+    from flashinfer.fused_moe.core import get_w2_permute_indices_with_cache
 
 
 @dataclass
@@ -319,7 +319,7 @@ def tg_mxfp4_moe(
     if transpose_optimized:
         for i in range(num_experts):
             # w13 weight shuffling
-            permute_indices = _maybe_get_cached_w2_permute_indices(
+            permute_indices = get_w2_permute_indices_with_cache(
                 _cache_permute_indices,
                 w13_weight[i].view(torch.uint8),
                 epilogue_tile_m,
@@ -330,7 +330,7 @@ def tg_mxfp4_moe(
                 .contiguous()
             )
             # w13 scale shuffling
-            permute_sf_indices = _maybe_get_cached_w2_permute_indices(
+            permute_sf_indices = get_w2_permute_indices_with_cache(
                 _cache_permute_indices,
                 w13_weight_scale[i].view(torch.uint8),
                 epilogue_tile_m,
@@ -344,7 +344,7 @@ def tg_mxfp4_moe(
                 )
             )
             # w13 bias shuffling
-            permute_bias_indices = _maybe_get_cached_w2_permute_indices(
+            permute_bias_indices = get_w2_permute_indices_with_cache(
                 _cache_permute_indices,
                 w13_bias[i].clone().reshape(-1, 1),
                 epilogue_tile_m,
@@ -356,7 +356,7 @@ def tg_mxfp4_moe(
                 .contiguous()
             )
             # w2 weight shuffling
-            permute_indices = _maybe_get_cached_w2_permute_indices(
+            permute_indices = get_w2_permute_indices_with_cache(
                 _cache_permute_indices,
                 w2_weight[i].view(torch.uint8),
                 epilogue_tile_m,
@@ -367,7 +367,7 @@ def tg_mxfp4_moe(
                 .contiguous()
             )
             # w2 scale shuffling
-            permute_sf_indices = _maybe_get_cached_w2_permute_indices(
+            permute_sf_indices = get_w2_permute_indices_with_cache(
                 _cache_permute_indices,
                 w2_weight_scale[i].view(torch.uint8),
                 epilogue_tile_m,
@@ -381,7 +381,7 @@ def tg_mxfp4_moe(
                 )
             )
             # w2 bias shuffling
-            permute_indices = _maybe_get_cached_w2_permute_indices(
+            permute_indices = get_w2_permute_indices_with_cache(
                 _cache_permute_indices,
                 w2_bias[i].clone().reshape(-1, 1),
                 epilogue_tile_m,
diff --git a/vllm/model_executor/layers/fused_moe/trtllm_moe.py b/vllm/model_executor/layers/fused_moe/trtllm_moe.py
index 0b0048c645..e305483eb1 100644
--- a/vllm/model_executor/layers/fused_moe/trtllm_moe.py
+++ b/vllm/model_executor/layers/fused_moe/trtllm_moe.py
@@ -11,7 +11,6 @@ from vllm.model_executor.layers.fused_moe.config import (
 from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
     TopKWeightAndReduceNoOP,
 )
-from vllm.utils import next_power_of_2
 
 
 class TrtLlmGenExperts(mk.FusedMoEPermuteExpertsUnpermute):
@@ -65,30 +64,6 @@ class TrtLlmGenExperts(mk.FusedMoEPermuteExpertsUnpermute):
         output = (M, K)
         return (workspace1, workspace2, output)
 
-    def _get_tile_tokens_dim(self, x: torch.Tensor, top_k: int, local_num_experts: int):
-        # Number of tokens in the input tensor.
-        num_tokens = x.shape[0]
-        # Factor to account for the imbalance of the experts.
-        # factor equals to the
-        # max_real_num_tokens_per_expert / perfect_num_tokens_per_expert
-        # 1.0 means perfect expert distribution.
-        # > 1.0 means some experts have more tokens than the perfect
-        # distribution.
-        # < 1.0 does not make sense.
-        imbalance_factor = 1.3
-        # Calculate the number of tokens per expert assuming perfect
-        # distribution.
-        num_tokens_per_expert = (num_tokens * top_k) // local_num_experts
-        # Apply the imbalance factor.
-        num_tokens_per_expert = int(num_tokens_per_expert * imbalance_factor)
-        # And pad the number to the next power of 2.
-        tile_tokens_dim = next_power_of_2(num_tokens_per_expert)
-        # Cap to 8-64 tokens per CTA tile as it's the range supported by the
-        # kernel.
-        tile_tokens_dim = min(max(tile_tokens_dim, 8), 64)
-
-        return tile_tokens_dim
-
     def apply(
         self,
         output: torch.Tensor,
@@ -148,9 +123,7 @@ class TrtLlmGenExperts(mk.FusedMoEPermuteExpertsUnpermute):
             "local_expert_offset": local_expert_offset,
             "local_num_experts": local_num_experts,
             "routed_scaling_factor": None,
-            "tile_tokens_dim": self._get_tile_tokens_dim(
-                x_quant, topk, local_num_experts
-            ),
+            "tile_tokens_dim": None,
             "routing_method_type": 1,
             "do_finalize": True,
             "output": output,
diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py
index 41f82de4ff..9d496f72eb 100644
--- a/vllm/model_executor/layers/quantization/modelopt.py
+++ b/vllm/model_executor/layers/quantization/modelopt.py
@@ -72,7 +72,6 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
 )
 from vllm.model_executor.parameter import ModelWeightParameter, PerTensorScaleParameter
 from vllm.scalar_type import scalar_types
-from vllm.utils import next_power_of_2
 from vllm.utils.flashinfer import (
     flashinfer_scaled_fp4_mm,
     has_flashinfer,
@@ -1125,16 +1124,6 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
         return out.view(*output_shape)
 
 
-def _get_tile_tokens_dim(num_tokens: int, top_k: int, num_experts: int) -> int:
-    # Guess tokens per expert assuming perfect expert distribution first.
-    num_tokens_per_expert = (num_tokens * top_k) // num_experts
-    # And pad the number to the next power of 2.
-    tile_tokens_dim = next_power_of_2(num_tokens_per_expert)
-    # Cap to 8-64 tokens per CTA tile as it's the range supported by the kernel.
-    tile_tokens_dim = min(max(tile_tokens_dim, 8), 64)
-    return tile_tokens_dim
-
-
 class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
     """
     MoE Method for FP4 Quantization.
@@ -1332,8 +1321,8 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
     ):
         from flashinfer import nvfp4_block_scale_interleave
         from flashinfer.fused_moe.core import (
-            _maybe_get_cached_w2_permute_indices,
             _maybe_get_cached_w3_w1_permute_indices,
+            get_w2_permute_indices_with_cache,
         )
 
         """Prepare quantized weights for kernel (done offline with weights)."""
@@ -1394,7 +1383,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
                 )
             )
 
-            permute_indices = _maybe_get_cached_w2_permute_indices(
+            permute_indices = get_w2_permute_indices_with_cache(
                 self._cache_permute_indices,
                 gemm2_weights_fp4[i].view(torch.uint8),
                 epilogue_tile_m,
@@ -1405,7 +1394,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
                 .contiguous()
             )
 
-            permute_sf_indices = _maybe_get_cached_w2_permute_indices(
+            permute_sf_indices = get_w2_permute_indices_with_cache(
                 self._cache_permute_indices,
                 gemm2_scales_linear_fp4[i].view(torch.uint8),
                 epilogue_tile_m,
@@ -1664,9 +1653,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
                 local_expert_offset=layer.ep_rank * layer.local_num_experts,
                 local_num_experts=layer.local_num_experts,
                 routed_scaling_factor=None,
-                tile_tokens_dim=_get_tile_tokens_dim(
-                    x.shape[0], top_k, layer.local_num_experts
-                ),
+                tile_tokens_dim=None,
                 routing_method_type=routing_method_type,
                 do_finalize=True,
             )[0]
diff --git a/vllm/model_executor/layers/quantization/mxfp4.py b/vllm/model_executor/layers/quantization/mxfp4.py
index 2eda2abfb4..5b35cf6df8 100644
--- a/vllm/model_executor/layers/quantization/mxfp4.py
+++ b/vllm/model_executor/layers/quantization/mxfp4.py
@@ -50,7 +50,6 @@ from vllm.scalar_type import scalar_types
 from vllm.utils import (
     has_triton_kernels,
     is_torch_equal_or_newer,
-    next_power_of_2,
     round_up,
 )
 from vllm.utils.flashinfer import has_flashinfer
@@ -97,12 +96,6 @@ def get_mxfp4_backend():
         and has_flashinfer()
         and envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8
     ):
-        logger.info_once(
-            "Using FlashInfer MXFP4 MXFP8 TRTLLM backend for SM100, "
-            "for high concurrency throughput workloads consider setting "
-            "VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS=1 for better "
-            "performance"
-        )
         return Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM
     elif current_platform.is_device_capability(100) and has_flashinfer():
         logger.info_once(
@@ -357,7 +350,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
             or self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_BF16
         ):
             from flashinfer.fp4_quantization import nvfp4_block_scale_interleave
-            from flashinfer.fused_moe.core import _maybe_get_cached_w2_permute_indices
+            from flashinfer.fused_moe.core import get_w2_permute_indices_with_cache
 
             layer.gemm1_alpha = Parameter(
                 torch.tensor([1.702] * self.num_experts, dtype=torch.float32).cuda(),
@@ -449,7 +442,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
            epilogue_tile_m = 128  # FIXME: this depends on the kernel internals
            for i in range(self.num_experts):
                # w13 weight shuffling
-                permute_indices = _maybe_get_cached_w2_permute_indices(
+                permute_indices = get_w2_permute_indices_with_cache(
                    self._cache_permute_indices,
                    w13_weight[i].view(torch.uint8),
                    epilogue_tile_m,
@@ -460,7 +453,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
                    .contiguous()
                )
                # w13 scale shuffling
-                permute_sf_indices = _maybe_get_cached_w2_permute_indices(
+                permute_sf_indices = get_w2_permute_indices_with_cache(
                    self._cache_permute_indices,
                    w13_weight_scale[i].view(torch.uint8),
                    epilogue_tile_m,
@@ -476,7 +469,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
                    )
                )
                # w13 bias shuffling
-                permute_bias_indices = _maybe_get_cached_w2_permute_indices(
+                permute_bias_indices = get_w2_permute_indices_with_cache(
                    self._cache_permute_indices,
                    w13_bias[i].clone().reshape(-1, 1),
                    epilogue_tile_m,
@@ -488,7 +481,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
                    .contiguous()
                )
                # w2 weight shuffling
-                permute_indices = _maybe_get_cached_w2_permute_indices(
+                permute_indices = get_w2_permute_indices_with_cache(
                    self._cache_permute_indices,
                    w2_weight[i].view(torch.uint8),
                    epilogue_tile_m,
@@ -499,7 +492,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
                    .contiguous()
                )
                # w2 scale shuffling
-                permute_sf_indices = _maybe_get_cached_w2_permute_indices(
+                permute_sf_indices = get_w2_permute_indices_with_cache(
                    self._cache_permute_indices,
                    w2_weight_scale[i].view(torch.uint8),
                    epilogue_tile_m,
@@ -515,7 +508,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
                    )
                )
                # w2 bias shuffling
-                permute_indices = _maybe_get_cached_w2_permute_indices(
+                permute_indices = get_w2_permute_indices_with_cache(
                    self._cache_permute_indices,
                    w2_bias[i].clone().reshape(-1, 1),
                    epilogue_tile_m,
@@ -735,30 +728,6 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
         else:
             raise ValueError(f"Unsupported backend: {self.mxfp4_backend}")
 
-    def _get_tile_tokens_dim(self, x: torch.Tensor, top_k: int):
-        # Number of tokens in the input tensor.
-        num_tokens = x.shape[0]
-        # Factor to account for the imbalance of the experts.
-        # factor equals to the
-        # max_real_num_tokens_per_expert / perfect_num_tokens_per_expert
-        # - 1.0 means perfect expert distribution.
-        # - > 1.0 means some experts have more
-        #   tokens than the perfect distribution.
-        # - < 1.0 does not make sense.
-        imbalance_factor = 1.3
-        # Calculate the number of tokens per expert
-        # assuming perfect distribution.
-        num_tokens_per_expert = (num_tokens * top_k) // self.num_experts
-        # Apply the imbalance factor.
-        num_tokens_per_expert = int(num_tokens_per_expert * imbalance_factor)
-        # And pad the number to the next power of 2.
-        tile_tokens_dim = next_power_of_2(num_tokens_per_expert)
-        # Cap to 8-64 tokens per CTA tile
-        # as it's the range supported by the kernel.
-        tile_tokens_dim = min(max(tile_tokens_dim, 8), 64)
-
-        return tile_tokens_dim
-
     def get_fused_moe_quant_config(
         self, layer: torch.nn.Module
     ) -> FusedMoEQuantConfig | None:
@@ -1037,7 +1006,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
                 layer.ep_rank * layer.local_num_experts,  # local_expert_offset
                 self.num_experts,  # local num experts
                 None,
-                self._get_tile_tokens_dim(x, top_k),
+                None,
                 1 if renormalize else 0,  # routing_method_type, renormalize
                 True,  # do finalize
                 tune_max_num_tokens=self.max_capture_size,
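For reference, the `_get_tile_tokens_dim` helpers deleted above all implement the same heuristic (the `modelopt.py` variant simply skips the imbalance factor), and every call site now passes `tile_tokens_dim=None`, presumably leaving the tile-size choice to FlashInfer 0.4.1 itself. The sketch below restates the retired heuristic as a standalone function; it is illustrative only, and `next_power_of_2` is reimplemented here (it was imported from `vllm.utils`) so the snippet has no vLLM dependency.

```python
# Minimal sketch of the removed tile_tokens_dim heuristic, under the assumptions above.


def next_power_of_2(n: int) -> int:
    # Smallest power of two >= n; 0 and 1 both map to 1 (stand-in for vllm.utils.next_power_of_2).
    return 1 if n <= 1 else 1 << (n - 1).bit_length()


def tile_tokens_dim(
    num_tokens: int, top_k: int, num_experts: int, imbalance_factor: float = 1.3
) -> int:
    # Tokens per expert under a perfectly balanced routing distribution.
    per_expert = (num_tokens * top_k) // num_experts
    # Inflate for expected imbalance, pad to the next power of two, and clamp to the
    # 8-64 tokens-per-CTA-tile range the kernel supports.
    padded = next_power_of_2(int(per_expert * imbalance_factor))
    return min(max(padded, 8), 64)


if __name__ == "__main__":
    # (256 * 8) // 128 = 16 tokens/expert -> 20 after the 1.3 factor -> padded to 32.
    print(tile_tokens_dim(num_tokens=256, top_k=8, num_experts=128))
```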