diff --git a/vllm/model_executor/layers/quantization/kv_cache.py b/vllm/model_executor/layers/quantization/kv_cache.py
index 275a1c43fd..27e2b7846d 100644
--- a/vllm/model_executor/layers/quantization/kv_cache.py
+++ b/vllm/model_executor/layers/quantization/kv_cache.py
@@ -86,7 +86,7 @@ class BaseKVCacheMethod(QuantizeMethodBase):
                 logger.warning_once(
                     "Checkpoint does not provide a q scaling factor. "
                     "Setting it to k_scale. This only matters for "
-                    "the flash-attn backend.")
+                    "FP8 Attention backends (flash-attn or flashinfer).")
                 layer._q_scale.copy_(k_scale)
                 layer._q_scale_float = k_scale
 
@@ -98,9 +98,9 @@ class BaseKVCacheMethod(QuantizeMethodBase):
             if (k_scale == 1.0 and v_scale == 1.0
                     and "e5m2" not in layer.kv_cache_dtype):
                 logger.warning_once(
-                    "Using KV cache scaling factor 1.0 for fp8_e4m3. This "
-                    "may cause accuracy issues. Please make sure k/v_scale "
-                    "scaling factors are available in the fp8 checkpoint.")
+                    "Using KV cache scaling factor 1.0 for fp8_e4m3. "
+                    "If this is unintended, verify that k/v_scale "
+                    "scaling factors are properly set in the checkpoint.")
 
         if layer.q_scale > 0.0:
             q_scale = layer.q_scale
diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py
index 20704439ea..1ca82cdcbc 100644
--- a/vllm/model_executor/layers/quantization/modelopt.py
+++ b/vllm/model_executor/layers/quantization/modelopt.py
@@ -1064,7 +1064,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
         self.allow_flashinfer = _nvfp4.allow_flashinfer
         self.use_marlin = _nvfp4.use_marlin
         self.flashinfer_moe_backend = None
-
+        self._cache_permute_indices: dict[torch.Size, torch.Tensor] = {}
         if self.allow_flashinfer:
             self.flashinfer_moe_backend = get_flashinfer_moe_backend()
             logger.info_once(
@@ -1197,19 +1197,23 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
                                                  weight_loader=weight_loader)
         layer.register_parameter("w2_input_scale", w2_input_scale)
 
-    def prepare_static_weight_layouts_for_trtllm_moe(
+    def prepare_static_weights_for_trtllm_fp4_moe(
         self,
-        gemm1_weights: torch.Tensor,
-        gemm2_weights: torch.Tensor,
-        gemm1_scales_linear_fp4_bytes: torch.Tensor,
-        gemm2_scales_linear_fp4_bytes: torch.Tensor,
-        hidden_size: int,
-        intermediate_size: int,
-        num_experts: int,
-    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+        # args_dequant,
+        # args,
+        gemm1_weights,
+        gemm2_weights,
+        gemm1_scales_linear_fp4_bytes,
+        gemm2_scales_linear_fp4_bytes,
+        hidden_size,
+        intermediate_size,
+        num_experts,
+    ):
+        from flashinfer import nvfp4_block_scale_interleave
+        from flashinfer.fused_moe.core import (
+            _maybe_get_cached_w2_permute_indices,
+            _maybe_get_cached_w3_w1_permute_indices)
         """Prepare quantized weights for kernel (done offline with weights)."""
-        from flashinfer import (reorder_rows_for_gated_act_gemm,
-                                shuffle_matrix_a, shuffle_matrix_sf_a)
         epilogue_tile_m = 128  # FIXME: this depends on the kernel internals
 
         # Convert quantized weights to proper formats
@@ -1227,48 +1231,54 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
                                          intermediate_size //
                                          16)  # fp8 scaling factors
 
-        # Reorder rows of W1 and scales for fused gated activation
-        gemm1_weights_fp4_interleaved = []
-        gemm1_scales_fp4_interleaved = []
-        for i in range(num_experts):
-            gemm1_weights_fp4_interleaved.append(
-                reorder_rows_for_gated_act_gemm(gemm1_weights_fp4[i].clone()))
-            gemm1_scales_fp4_interleaved.append(
-                reorder_rows_for_gated_act_gemm(
-                    gemm1_scales_linear_fp4[i].clone()))
-
-        # Stack weights and scales for all experts
-        gemm1_weights_fp4_interleaved = torch.stack(
-            gemm1_weights_fp4_interleaved).reshape(num_experts,
-                                                   2 * intermediate_size,
-                                                   hidden_size // 2)
-        gemm1_scales_fp4_interleaved = torch.stack(
-            gemm1_scales_fp4_interleaved).reshape(num_experts,
-                                                  2 * intermediate_size,
-                                                  hidden_size // 16)
-
-        # Shuffle weights and scaling factors for transposed mma output
         gemm1_weights_fp4_shuffled = []
         gemm1_scales_fp4_shuffled = []
         gemm2_weights_fp4_shuffled = []
         gemm2_scales_fp4_shuffled = []
         for i in range(num_experts):
-            gemm1_weights_fp4_shuffled.append(
-                shuffle_matrix_a(
-                    gemm1_weights_fp4_interleaved[i].view(torch.uint8),
-                    epilogue_tile_m))
-            gemm1_scales_fp4_shuffled.append(
-                shuffle_matrix_sf_a(
-                    gemm1_scales_fp4_interleaved[i].view(torch.uint8),
-                    epilogue_tile_m))
+            # Calculate the permute indices for the following:
+            # 1. Reorder rows of W1 and scales for fused gated activation
+            # 2. Shuffle weights and scaling factors for transposed mma output
+            #    for both w3_w1 and w2 weights and scale factors
+            permute_indices = _maybe_get_cached_w3_w1_permute_indices(
+                self._cache_permute_indices,
+                gemm1_weights_fp4[i].view(torch.uint8),
+                epilogue_tile_m,
+            )
+            gemm1_weights_fp4_shuffled.append(gemm1_weights_fp4[i].view(
+                torch.uint8)[permute_indices.to(
+                    gemm1_weights_fp4.device)].contiguous())
 
-            gemm2_weights_fp4_shuffled.append(
-                shuffle_matrix_a(gemm2_weights_fp4[i].view(torch.uint8),
-                                 epilogue_tile_m))
+            permute_sf_indices = _maybe_get_cached_w3_w1_permute_indices(
+                self._cache_permute_indices,
+                gemm1_scales_linear_fp4[i].view(torch.uint8),
+                epilogue_tile_m,
+                num_elts_per_sf=16,
+            )
+            gemm1_scales_fp4_shuffled.append(
+                nvfp4_block_scale_interleave(gemm1_scales_linear_fp4[i].view(
+                    torch.uint8)[permute_sf_indices.to(
+                        gemm1_scales_linear_fp4.device)].contiguous()))
+
+            permute_indices = _maybe_get_cached_w2_permute_indices(
+                self._cache_permute_indices,
+                gemm2_weights_fp4[i].view(torch.uint8),
+                epilogue_tile_m,
+            )
+            gemm2_weights_fp4_shuffled.append(gemm2_weights_fp4[i].view(
+                torch.uint8)[permute_indices.to(
+                    gemm2_weights_fp4.device)].contiguous())
+
+            permute_sf_indices = _maybe_get_cached_w2_permute_indices(
+                self._cache_permute_indices,
+                gemm2_scales_linear_fp4[i].view(torch.uint8),
+                epilogue_tile_m,
+                num_elts_per_sf=16,
+            )
             gemm2_scales_fp4_shuffled.append(
-                shuffle_matrix_sf_a(
-                    gemm2_scales_linear_fp4[i].view(torch.uint8),
-                    epilogue_tile_m))
+                nvfp4_block_scale_interleave(gemm2_scales_linear_fp4[i].view(
+                    torch.uint8)[permute_sf_indices.to(
+                        gemm2_scales_linear_fp4.device)].contiguous()))
 
         # Stack weights for all experts
         gemm1_weights_fp4_shuffled = torch.stack(gemm1_weights_fp4_shuffled)
@@ -1283,8 +1293,12 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
             torch.stack(gemm2_scales_fp4_shuffled).view(
                 torch.float8_e4m3fn).reshape(num_experts, hidden_size,
                                              intermediate_size // 16))
-        return (gemm1_weights_fp4_shuffled, gemm1_scales_fp4_shuffled,
-                gemm2_weights_fp4_shuffled, gemm2_scales_fp4_shuffled)
+        return (
+            gemm1_weights_fp4_shuffled,
+            gemm1_scales_fp4_shuffled,
+            gemm2_weights_fp4_shuffled,
+            gemm2_scales_fp4_shuffled,
+        )
 
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         # GEMM 1 processing
@@ -1334,9 +1348,10 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
         if self.allow_flashinfer and \
             self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
             # Prepare static weights for TRT-LLM kernel
+            # alternate: prepare_static_weight_layouts_for_trtllm_moe
             (gemm1_weights_fp4_shuffled, gemm1_scales_fp4_shuffled,
              gemm2_weights_fp4_shuffled, gemm2_scales_fp4_shuffled
-             ) = self.prepare_static_weight_layouts_for_trtllm_moe(
+             ) = self.prepare_static_weights_for_trtllm_fp4_moe(
                 layer.w13_weight,
                 layer.w2_weight,
                 layer.w13_weight_scale,
@@ -1345,6 +1360,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
                 layer.w13_weight.size(-2) // 2,  # intermediate_size
                 layer.w13_weight.size(0),  # num_experts
             )
+            logger.debug_once("Finished shuffling weights for TRT-LLM MOE")
 
             layer.gemm1_weights_fp4_shuffled = Parameter(
                 gemm1_weights_fp4_shuffled, requires_grad=False)
diff --git a/vllm/model_executor/model_loader/weight_utils.py b/vllm/model_executor/model_loader/weight_utils.py
index bbed43b175..6c5f7bbcc8 100644
--- a/vllm/model_executor/model_loader/weight_utils.py
+++ b/vllm/model_executor/model_loader/weight_utils.py
@@ -1003,12 +1003,18 @@ def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> Optional[str]:
             return None
         return remapped_name
 
+    if any("mla_attn" in key for key in params_dict):
+        attn_str = "mla_attn.mla_attn"
+        logger.debug_once(f"Found mla_attn with k_scale and v_scale in "
+                          f"the checkpoint, using {attn_str} as attn_str")
+    else:
+        attn_str = "attn"
     # Define scale name mapping patterns in order of precedence
     scale_mapping_patterns = [
         # ModelOpt format: .self_attn.{k,v}_proj.{k,v}_scale ->
        # .self_attn.attn.{k,v}_scale
         (r"\.self_attn\.([kv])_proj\.([kv])_scale$",
-         r".self_attn.attn.\2_scale"),
+         rf".self_attn.{attn_str}.\2_scale"),
         # QKV proj format: .self_attn.qkv_proj.{k,v}_scale ->
         # .self_attn.attn.{k,v}_scale
         (r"\.self_attn\.qkv_proj\.([kv])_scale$", r".self_attn.attn.\1_scale"),
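The weight_utils.py hunk makes the ModelOpt scale-name remapping target-aware: if the parameter dict contains MLA attention modules, k_scale/v_scale names are rewritten onto "mla_attn.mla_attn" instead of "attn". A minimal, self-contained illustration of that substitution follows; the layer name below is a made-up example, not taken from a real checkpoint.

import re

# Hypothetical parameter names, for illustration only.
params_dict = {"model.layers.0.self_attn.mla_attn.mla_attn.k_scale": None}

attn_str = ("mla_attn.mla_attn"
            if any("mla_attn" in key for key in params_dict) else "attn")

pattern, repl = (r"\.self_attn\.([kv])_proj\.([kv])_scale$",
                 rf".self_attn.{attn_str}.\2_scale")

name = "model.layers.0.self_attn.k_proj.k_scale"
print(re.sub(pattern, repl, name))
# -> model.layers.0.self_attn.mla_attn.mla_attn.k_scale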
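The modelopt.py hunks replace the per-expert reorder/shuffle calls with permute indices that are computed once per weight shape and memoized in self._cache_permute_indices. Below is a minimal sketch of that shape-keyed memoization pattern only; compute_permute_indices is a hypothetical stand-in, not the flashinfer _maybe_get_cached_*_permute_indices helpers used in the diff.

import torch

# Shape-keyed cache, analogous to self._cache_permute_indices in the diff.
_cache_permute_indices: dict[torch.Size, torch.Tensor] = {}


def compute_permute_indices(shape: torch.Size, epilogue_tile_m: int) -> torch.Tensor:
    # Stand-in: identity permutation. The real helpers return a row permutation
    # matching the kernel's expected weight layout.
    return torch.arange(shape[0])


def maybe_get_cached_permute_indices(weight: torch.Tensor,
                                     epilogue_tile_m: int) -> torch.Tensor:
    # All experts share a weight shape, so the indices are computed once and
    # reused for every expert instead of being recomputed per expert.
    key = weight.shape
    if key not in _cache_permute_indices:
        _cache_permute_indices[key] = compute_permute_indices(key, epilogue_tile_m)
    return _cache_permute_indices[key]


# Example: shuffle rows of each expert's weight with the cached indices.
experts = [torch.randn(128, 64) for _ in range(8)]
shuffled = [
    w[maybe_get_cached_permute_indices(w, epilogue_tile_m=128)].contiguous()
    for w in experts
]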