Mirror of https://github.com/vllm-project/vllm.git (synced 2025-10-20 14:53:52 +08:00)
[Quantization/NVFP4] Speed up TRTLLM NVFP4 MOE weight loading and fix K/V scale loading for MLA Attn (#25968)
Signed-off-by: Pavani Majety <pmajety@nvidia.com>
Signed-off-by: yewentao256 <zhyanwentao@126.com>
Committed by: yewentao256
Parent: 9ea82ecd25
Commit: 920db41128
@@ -86,7 +86,7 @@ class BaseKVCacheMethod(QuantizeMethodBase):
                 logger.warning_once(
                     "Checkpoint does not provide a q scaling factor. "
                     "Setting it to k_scale. This only matters for "
-                    "the flash-attn backend.")
+                    "FP8 Attention backends (flash-attn or flashinfer).")
                 layer._q_scale.copy_(k_scale)
                 layer._q_scale_float = k_scale
 
@@ -98,9 +98,9 @@ class BaseKVCacheMethod(QuantizeMethodBase):
         if (k_scale == 1.0 and v_scale == 1.0
                 and "e5m2" not in layer.kv_cache_dtype):
             logger.warning_once(
-                "Using KV cache scaling factor 1.0 for fp8_e4m3. This "
-                "may cause accuracy issues. Please make sure k/v_scale "
-                "scaling factors are available in the fp8 checkpoint.")
+                "Using KV cache scaling factor 1.0 for fp8_e4m3. "
+                "If this is unintended, verify that k/v_scale "
+                "scaling factors are properly set in the checkpoint.")
 
         if layer.q_scale > 0.0:
             q_scale = layer.q_scale
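For context, a minimal standalone sketch (not part of the diff) of the scale handling these two warnings describe. It assumes a negative q_scale means "not provided by the checkpoint", mirroring the `layer.q_scale > 0.0` check above; the function name and plain-float arguments are illustrative only.

def resolve_attention_scales(q_scale: float, k_scale: float, v_scale: float,
                             kv_cache_dtype: str) -> tuple[float, float, float]:
    if q_scale < 0.0:
        # No q scaling factor in the checkpoint: fall back to k_scale.
        # This only matters for FP8 attention backends (flash-attn/flashinfer).
        q_scale = k_scale
    if k_scale == 1.0 and v_scale == 1.0 and "e5m2" not in kv_cache_dtype:
        # Default scales of 1.0 with fp8_e4m3 usually mean the checkpoint
        # carried no calibrated k/v scales.
        print("warning: using KV cache scaling factor 1.0 for fp8_e4m3")
    return q_scale, k_scale, v_scale

print(resolve_attention_scales(-1.0, 1.2, 0.9, "fp8_e4m3"))  # -> (1.2, 1.2, 0.9)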
@@ -1064,7 +1064,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
         self.allow_flashinfer = _nvfp4.allow_flashinfer
         self.use_marlin = _nvfp4.use_marlin
         self.flashinfer_moe_backend = None
-
+        self._cache_permute_indices: dict[torch.Size, torch.Tensor] = {}
         if self.allow_flashinfer:
             self.flashinfer_moe_backend = get_flashinfer_moe_backend()
             logger.info_once(
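A rough sketch (not part of the diff) of the caching idea behind the new `_cache_permute_indices` dict: it is keyed by weight shape, so the permutation indices used later in this commit are computed once per unique shape and reused across experts and layers. `get_cached_permute_indices` and `compute_indices` are hypothetical stand-ins for the FlashInfer helpers imported further down.

import torch
from typing import Callable

def get_cached_permute_indices(
        cache: dict[torch.Size, torch.Tensor],
        weight: torch.Tensor,
        compute_indices: Callable[[torch.Size], torch.Tensor]) -> torch.Tensor:
    # All experts of a layer share one weight shape, so the (relatively
    # expensive) index computation runs once; later lookups are dict hits.
    if weight.shape not in cache:
        cache[weight.shape] = compute_indices(weight.shape)
    return cache[weight.shape]

cache: dict[torch.Size, torch.Tensor] = {}
w = torch.empty(256, 128, dtype=torch.uint8)
idx = get_cached_permute_indices(cache, w, lambda s: torch.randperm(s[0]))
idx_again = get_cached_permute_indices(cache, w, lambda s: torch.randperm(s[0]))
assert torch.equal(idx, idx_again)  # second call is a cache hit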
@@ -1197,19 +1197,23 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
                                           weight_loader=weight_loader)
         layer.register_parameter("w2_input_scale", w2_input_scale)
 
-    def prepare_static_weight_layouts_for_trtllm_moe(
+    def prepare_static_weights_for_trtllm_fp4_moe(
         self,
-        gemm1_weights: torch.Tensor,
-        gemm2_weights: torch.Tensor,
-        gemm1_scales_linear_fp4_bytes: torch.Tensor,
-        gemm2_scales_linear_fp4_bytes: torch.Tensor,
-        hidden_size: int,
-        intermediate_size: int,
-        num_experts: int,
-    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+        # args_dequant,
+        # args,
+        gemm1_weights,
+        gemm2_weights,
+        gemm1_scales_linear_fp4_bytes,
+        gemm2_scales_linear_fp4_bytes,
+        hidden_size,
+        intermediate_size,
+        num_experts,
+    ):
+        from flashinfer import nvfp4_block_scale_interleave
+        from flashinfer.fused_moe.core import (
+            _maybe_get_cached_w2_permute_indices,
+            _maybe_get_cached_w3_w1_permute_indices)
         """Prepare quantized weights for kernel (done offline with weights)."""
-        from flashinfer import (reorder_rows_for_gated_act_gemm,
-                                shuffle_matrix_a, shuffle_matrix_sf_a)
         epilogue_tile_m = 128  # FIXME: this depends on the kernel internals
 
         # Convert quantized weights to proper formats
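A worked example (not part of the diff) of the NVFP4 size arithmetic behind the reshapes in the next hunk: two FP4 values are packed per uint8 byte, and one FP8 scale factor covers a 16-element block (matching `num_elts_per_sf=16`). The concrete sizes below are made up.

# Example NVFP4 packing arithmetic for a fused w3/w1 (gate/up) expert weight.
hidden_size = 4096
intermediate_size = 14336
w13_bytes_per_row = hidden_size // 2     # 2048: two FP4 values per uint8 byte
w13_scales_per_row = hidden_size // 16   # 256: one FP8 scale per 16 elements
w13_rows = 2 * intermediate_size         # 28672: gate and up stacked per expert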
@@ -1227,48 +1231,54 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
                                                intermediate_size //
                                                16)  # fp8 scaling factors
 
-        # Reorder rows of W1 and scales for fused gated activation
-        gemm1_weights_fp4_interleaved = []
-        gemm1_scales_fp4_interleaved = []
-        for i in range(num_experts):
-            gemm1_weights_fp4_interleaved.append(
-                reorder_rows_for_gated_act_gemm(gemm1_weights_fp4[i].clone()))
-            gemm1_scales_fp4_interleaved.append(
-                reorder_rows_for_gated_act_gemm(
-                    gemm1_scales_linear_fp4[i].clone()))
-
-        # Stack weights and scales for all experts
-        gemm1_weights_fp4_interleaved = torch.stack(
-            gemm1_weights_fp4_interleaved).reshape(num_experts,
-                                                   2 * intermediate_size,
-                                                   hidden_size // 2)
-        gemm1_scales_fp4_interleaved = torch.stack(
-            gemm1_scales_fp4_interleaved).reshape(num_experts,
-                                                  2 * intermediate_size,
-                                                  hidden_size // 16)
-
-        # Shuffle weights and scaling factors for transposed mma output
         gemm1_weights_fp4_shuffled = []
         gemm1_scales_fp4_shuffled = []
         gemm2_weights_fp4_shuffled = []
         gemm2_scales_fp4_shuffled = []
         for i in range(num_experts):
-            gemm1_weights_fp4_shuffled.append(
-                shuffle_matrix_a(
-                    gemm1_weights_fp4_interleaved[i].view(torch.uint8),
-                    epilogue_tile_m))
-            gemm1_scales_fp4_shuffled.append(
-                shuffle_matrix_sf_a(
-                    gemm1_scales_fp4_interleaved[i].view(torch.uint8),
-                    epilogue_tile_m))
+            # Calculate the permute indices for the following:
+            # 1. Reorder rows of W1 and scales for fused gated activation
+            # 2. Shuffle weights and scaling factors for transposed mma output
+            # for both w3_w1 and w2 weights and scale factors
+            permute_indices = _maybe_get_cached_w3_w1_permute_indices(
+                self._cache_permute_indices,
+                gemm1_weights_fp4[i].view(torch.uint8),
+                epilogue_tile_m,
+            )
+            gemm1_weights_fp4_shuffled.append(gemm1_weights_fp4[i].view(
+                torch.uint8)[permute_indices.to(
+                    gemm1_weights_fp4.device)].contiguous())
 
-            gemm2_weights_fp4_shuffled.append(
-                shuffle_matrix_a(gemm2_weights_fp4[i].view(torch.uint8),
-                                 epilogue_tile_m))
+            permute_sf_indices = _maybe_get_cached_w3_w1_permute_indices(
+                self._cache_permute_indices,
+                gemm1_scales_linear_fp4[i].view(torch.uint8),
+                epilogue_tile_m,
+                num_elts_per_sf=16,
+            )
+            gemm1_scales_fp4_shuffled.append(
+                nvfp4_block_scale_interleave(gemm1_scales_linear_fp4[i].view(
+                    torch.uint8)[permute_sf_indices.to(
+                        gemm1_scales_linear_fp4.device)].contiguous()))
+
+            permute_indices = _maybe_get_cached_w2_permute_indices(
+                self._cache_permute_indices,
+                gemm2_weights_fp4[i].view(torch.uint8),
+                epilogue_tile_m,
+            )
+            gemm2_weights_fp4_shuffled.append(gemm2_weights_fp4[i].view(
+                torch.uint8)[permute_indices.to(
+                    gemm2_weights_fp4.device)].contiguous())
+
+            permute_sf_indices = _maybe_get_cached_w2_permute_indices(
+                self._cache_permute_indices,
+                gemm2_scales_linear_fp4[i].view(torch.uint8),
+                epilogue_tile_m,
+                num_elts_per_sf=16,
+            )
             gemm2_scales_fp4_shuffled.append(
-                shuffle_matrix_sf_a(
-                    gemm2_scales_linear_fp4[i].view(torch.uint8),
-                    epilogue_tile_m))
+                nvfp4_block_scale_interleave(gemm2_scales_linear_fp4[i].view(
+                    torch.uint8)[permute_sf_indices.to(
+                        gemm2_scales_linear_fp4.device)].contiguous()))
 
         # Stack weights for all experts
         gemm1_weights_fp4_shuffled = torch.stack(gemm1_weights_fp4_shuffled)
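The structural change above, in miniature: instead of calling a shuffle routine for every expert on every load, the new code applies a cached row permutation as a plain gather, `w[permute_indices.to(w.device)].contiguous()`. A standalone sketch (not the actual FlashInfer kernels), assuming for illustration that the shuffle is just a row reversal:

import torch

def shuffle_rows_slow(w: torch.Tensor) -> torch.Tensor:
    # Stand-in for recomputing the layout on every call; here it simply
    # reverses the rows.
    return torch.flip(w, dims=[0])

def shuffle_rows_cached(w: torch.Tensor,
                        cache: dict[torch.Size, torch.Tensor]) -> torch.Tensor:
    # Same result, but the row permutation is computed once per shape and
    # then applied as a single gather, mirroring the pattern in the diff.
    if w.shape not in cache:
        cache[w.shape] = torch.arange(w.shape[0] - 1, -1, -1)
    return w[cache[w.shape].to(w.device)].contiguous()

w = torch.arange(12, dtype=torch.float32).reshape(4, 3)
cache: dict[torch.Size, torch.Tensor] = {}
assert torch.equal(shuffle_rows_slow(w), shuffle_rows_cached(w, cache))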
@@ -1283,8 +1293,12 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
             torch.stack(gemm2_scales_fp4_shuffled).view(
                 torch.float8_e4m3fn).reshape(num_experts, hidden_size,
                                              intermediate_size // 16))
-        return (gemm1_weights_fp4_shuffled, gemm1_scales_fp4_shuffled,
-                gemm2_weights_fp4_shuffled, gemm2_scales_fp4_shuffled)
+        return (
+            gemm1_weights_fp4_shuffled,
+            gemm1_scales_fp4_shuffled,
+            gemm2_weights_fp4_shuffled,
+            gemm2_scales_fp4_shuffled,
+        )
 
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         # GEMM 1 processing
@@ -1334,9 +1348,10 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
         if self.allow_flashinfer and \
            self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
             # Prepare static weights for TRT-LLM kernel
+            # alternate: prepare_static_weight_layouts_for_trtllm_moe
             (gemm1_weights_fp4_shuffled, gemm1_scales_fp4_shuffled,
              gemm2_weights_fp4_shuffled, gemm2_scales_fp4_shuffled
-             ) = self.prepare_static_weight_layouts_for_trtllm_moe(
+             ) = self.prepare_static_weights_for_trtllm_fp4_moe(
                  layer.w13_weight,
                  layer.w2_weight,
                  layer.w13_weight_scale,
@@ -1345,6 +1360,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
                  layer.w13_weight.size(-2) // 2,  # intermediate_size
                  layer.w13_weight.size(0),  # num_experts
              )
+            logger.debug_once("Finished shuffling weights for TRT-LLM MOE")
 
             layer.gemm1_weights_fp4_shuffled = Parameter(
                 gemm1_weights_fp4_shuffled, requires_grad=False)
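For reference, the last two arguments above are recovered from the fused w13 weight's shape; a small sketch (with made-up sizes) of that arithmetic, assuming the packed `(num_experts, 2 * intermediate_size, hidden_size // 2)` uint8 layout established earlier in the diff:

import torch

num_experts, intermediate_size, hidden_size = 8, 256, 128
w13_weight = torch.empty(num_experts, 2 * intermediate_size, hidden_size // 2,
                         dtype=torch.uint8)

# The call site above derives the MoE dimensions from this shape:
assert w13_weight.size(0) == num_experts               # num_experts
assert w13_weight.size(-2) // 2 == intermediate_size   # intermediate_size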
@@ -1003,12 +1003,18 @@ def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> Optional[str]:
             return None
         return remapped_name
 
+    if any("mla_attn" in key for key in params_dict):
+        attn_str = "mla_attn.mla_attn"
+        logger.debug_once(f"Found mla_attn with k_scale and v_scale in "
+                          f"the checkpoint, using {attn_str} as attn_str")
+    else:
+        attn_str = "attn"
     # Define scale name mapping patterns in order of precedence
     scale_mapping_patterns = [
         # ModelOpt format: .self_attn.{k,v}_proj.{k,v}_scale ->
         # .self_attn.attn.{k,v}_scale
         (r"\.self_attn\.([kv])_proj\.([kv])_scale$",
-         r".self_attn.attn.\2_scale"),
+         rf".self_attn.{attn_str}.\2_scale"),
         # QKV proj format: .self_attn.qkv_proj.{k,v}_scale ->
         # .self_attn.attn.{k,v}_scale
         (r"\.self_attn\.qkv_proj\.([kv])_scale$", r".self_attn.attn.\1_scale"),
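A standalone sketch (not part of the diff) of what the MLA-aware change accomplishes, reusing the ModelOpt-format pattern from the hunk; the example parameter names are illustrative:

import re

def remap_modelopt_scale_name(name: str, attn_str: str) -> str:
    # Mirrors the first mapping pattern above: .self_attn.{k,v}_proj.{k,v}_scale
    # is rewritten onto the attention submodule chosen by attn_str.
    return re.sub(r"\.self_attn\.([kv])_proj\.([kv])_scale$",
                  rf".self_attn.{attn_str}.\2_scale", name)

# Standard attention checkpoints keep mapping onto "attn":
print(remap_modelopt_scale_name(
    "model.layers.0.self_attn.k_proj.k_scale", "attn"))
# -> model.layers.0.self_attn.attn.k_scale

# MLA models now map onto the nested "mla_attn.mla_attn" module instead:
print(remap_modelopt_scale_name(
    "model.layers.0.self_attn.k_proj.k_scale", "mla_attn.mla_attn"))
# -> model.layers.0.self_attn.mla_attn.mla_attn.k_scale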