[Quantization/NVFP4] Speed up TRTLLM NVFP4 MOE weight loading and fix K/V scale loading for MLA Attn (#25968)

Signed-off-by: Pavani Majety <pmajety@nvidia.com>
Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
Pavani Majety
2025-10-03 12:35:06 -07:00
committed by yewentao256
parent 9ea82ecd25
commit 920db41128
3 changed files with 77 additions and 55 deletions

View File

@ -86,7 +86,7 @@ class BaseKVCacheMethod(QuantizeMethodBase):
logger.warning_once( logger.warning_once(
"Checkpoint does not provide a q scaling factor. " "Checkpoint does not provide a q scaling factor. "
"Setting it to k_scale. This only matters for " "Setting it to k_scale. This only matters for "
"the flash-attn backend.") "FP8 Attention backends (flash-attn or flashinfer).")
layer._q_scale.copy_(k_scale) layer._q_scale.copy_(k_scale)
layer._q_scale_float = k_scale layer._q_scale_float = k_scale
@ -98,9 +98,9 @@ class BaseKVCacheMethod(QuantizeMethodBase):
if (k_scale == 1.0 and v_scale == 1.0 if (k_scale == 1.0 and v_scale == 1.0
and "e5m2" not in layer.kv_cache_dtype): and "e5m2" not in layer.kv_cache_dtype):
logger.warning_once( logger.warning_once(
"Using KV cache scaling factor 1.0 for fp8_e4m3. This " "Using KV cache scaling factor 1.0 for fp8_e4m3. "
"may cause accuracy issues. Please make sure k/v_scale " "If this is unintended, verify that k/v_scale "
"scaling factors are available in the fp8 checkpoint.") "scaling factors are properly set in the checkpoint.")
if layer.q_scale > 0.0: if layer.q_scale > 0.0:
q_scale = layer.q_scale q_scale = layer.q_scale

View File

@ -1064,7 +1064,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
self.allow_flashinfer = _nvfp4.allow_flashinfer self.allow_flashinfer = _nvfp4.allow_flashinfer
self.use_marlin = _nvfp4.use_marlin self.use_marlin = _nvfp4.use_marlin
self.flashinfer_moe_backend = None self.flashinfer_moe_backend = None
self._cache_permute_indices: dict[torch.Size, torch.Tensor] = {}
if self.allow_flashinfer: if self.allow_flashinfer:
self.flashinfer_moe_backend = get_flashinfer_moe_backend() self.flashinfer_moe_backend = get_flashinfer_moe_backend()
logger.info_once( logger.info_once(
@ -1197,19 +1197,23 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
weight_loader=weight_loader) weight_loader=weight_loader)
layer.register_parameter("w2_input_scale", w2_input_scale) layer.register_parameter("w2_input_scale", w2_input_scale)
def prepare_static_weight_layouts_for_trtllm_moe( def prepare_static_weights_for_trtllm_fp4_moe(
self, self,
gemm1_weights: torch.Tensor, # args_dequant,
gemm2_weights: torch.Tensor, # args,
gemm1_scales_linear_fp4_bytes: torch.Tensor, gemm1_weights,
gemm2_scales_linear_fp4_bytes: torch.Tensor, gemm2_weights,
hidden_size: int, gemm1_scales_linear_fp4_bytes,
intermediate_size: int, gemm2_scales_linear_fp4_bytes,
num_experts: int, hidden_size,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: intermediate_size,
num_experts,
):
from flashinfer import nvfp4_block_scale_interleave
from flashinfer.fused_moe.core import (
_maybe_get_cached_w2_permute_indices,
_maybe_get_cached_w3_w1_permute_indices)
"""Prepare quantized weights for kernel (done offline with weights).""" """Prepare quantized weights for kernel (done offline with weights)."""
from flashinfer import (reorder_rows_for_gated_act_gemm,
shuffle_matrix_a, shuffle_matrix_sf_a)
epilogue_tile_m = 128 # FIXME: this depends on the kernel internals epilogue_tile_m = 128 # FIXME: this depends on the kernel internals
# Convert quantized weights to proper formats # Convert quantized weights to proper formats
@ -1227,48 +1231,54 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
intermediate_size // intermediate_size //
16) # fp8 scaling factors 16) # fp8 scaling factors
# Reorder rows of W1 and scales for fused gated activation
gemm1_weights_fp4_interleaved = []
gemm1_scales_fp4_interleaved = []
for i in range(num_experts):
gemm1_weights_fp4_interleaved.append(
reorder_rows_for_gated_act_gemm(gemm1_weights_fp4[i].clone()))
gemm1_scales_fp4_interleaved.append(
reorder_rows_for_gated_act_gemm(
gemm1_scales_linear_fp4[i].clone()))
# Stack weights and scales for all experts
gemm1_weights_fp4_interleaved = torch.stack(
gemm1_weights_fp4_interleaved).reshape(num_experts,
2 * intermediate_size,
hidden_size // 2)
gemm1_scales_fp4_interleaved = torch.stack(
gemm1_scales_fp4_interleaved).reshape(num_experts,
2 * intermediate_size,
hidden_size // 16)
# Shuffle weights and scaling factors for transposed mma output
gemm1_weights_fp4_shuffled = [] gemm1_weights_fp4_shuffled = []
gemm1_scales_fp4_shuffled = [] gemm1_scales_fp4_shuffled = []
gemm2_weights_fp4_shuffled = [] gemm2_weights_fp4_shuffled = []
gemm2_scales_fp4_shuffled = [] gemm2_scales_fp4_shuffled = []
for i in range(num_experts): for i in range(num_experts):
gemm1_weights_fp4_shuffled.append( # Calculate the permute indices for the following:
shuffle_matrix_a( # 1. Reorder rows of W1 and scales for fused gated activation
gemm1_weights_fp4_interleaved[i].view(torch.uint8), # 2. Shuffle weights and scaling factors for transposed mma output
epilogue_tile_m)) # for both w3_w1 and w2 weights and scale factors
gemm1_scales_fp4_shuffled.append( permute_indices = _maybe_get_cached_w3_w1_permute_indices(
shuffle_matrix_sf_a( self._cache_permute_indices,
gemm1_scales_fp4_interleaved[i].view(torch.uint8), gemm1_weights_fp4[i].view(torch.uint8),
epilogue_tile_m)) epilogue_tile_m,
)
gemm1_weights_fp4_shuffled.append(gemm1_weights_fp4[i].view(
torch.uint8)[permute_indices.to(
gemm1_weights_fp4.device)].contiguous())
gemm2_weights_fp4_shuffled.append( permute_sf_indices = _maybe_get_cached_w3_w1_permute_indices(
shuffle_matrix_a(gemm2_weights_fp4[i].view(torch.uint8), self._cache_permute_indices,
epilogue_tile_m)) gemm1_scales_linear_fp4[i].view(torch.uint8),
epilogue_tile_m,
num_elts_per_sf=16,
)
gemm1_scales_fp4_shuffled.append(
nvfp4_block_scale_interleave(gemm1_scales_linear_fp4[i].view(
torch.uint8)[permute_sf_indices.to(
gemm1_scales_linear_fp4.device)].contiguous()))
permute_indices = _maybe_get_cached_w2_permute_indices(
self._cache_permute_indices,
gemm2_weights_fp4[i].view(torch.uint8),
epilogue_tile_m,
)
gemm2_weights_fp4_shuffled.append(gemm2_weights_fp4[i].view(
torch.uint8)[permute_indices.to(
gemm2_weights_fp4.device)].contiguous())
permute_sf_indices = _maybe_get_cached_w2_permute_indices(
self._cache_permute_indices,
gemm2_scales_linear_fp4[i].view(torch.uint8),
epilogue_tile_m,
num_elts_per_sf=16,
)
gemm2_scales_fp4_shuffled.append( gemm2_scales_fp4_shuffled.append(
shuffle_matrix_sf_a( nvfp4_block_scale_interleave(gemm2_scales_linear_fp4[i].view(
gemm2_scales_linear_fp4[i].view(torch.uint8), torch.uint8)[permute_sf_indices.to(
epilogue_tile_m)) gemm2_scales_linear_fp4.device)].contiguous()))
# Stack weights for all experts # Stack weights for all experts
gemm1_weights_fp4_shuffled = torch.stack(gemm1_weights_fp4_shuffled) gemm1_weights_fp4_shuffled = torch.stack(gemm1_weights_fp4_shuffled)
@ -1283,8 +1293,12 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
torch.stack(gemm2_scales_fp4_shuffled).view( torch.stack(gemm2_scales_fp4_shuffled).view(
torch.float8_e4m3fn).reshape(num_experts, hidden_size, torch.float8_e4m3fn).reshape(num_experts, hidden_size,
intermediate_size // 16)) intermediate_size // 16))
return (gemm1_weights_fp4_shuffled, gemm1_scales_fp4_shuffled, return (
gemm2_weights_fp4_shuffled, gemm2_scales_fp4_shuffled) gemm1_weights_fp4_shuffled,
gemm1_scales_fp4_shuffled,
gemm2_weights_fp4_shuffled,
gemm2_scales_fp4_shuffled,
)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None: def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
# GEMM 1 processing # GEMM 1 processing
@ -1334,9 +1348,10 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
if self.allow_flashinfer and \ if self.allow_flashinfer and \
self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM: self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
# Prepare static weights for TRT-LLM kernel # Prepare static weights for TRT-LLM kernel
# alternate: prepare_static_weight_layouts_for_trtllm_moe
(gemm1_weights_fp4_shuffled, gemm1_scales_fp4_shuffled, (gemm1_weights_fp4_shuffled, gemm1_scales_fp4_shuffled,
gemm2_weights_fp4_shuffled, gemm2_scales_fp4_shuffled gemm2_weights_fp4_shuffled, gemm2_scales_fp4_shuffled
) = self.prepare_static_weight_layouts_for_trtllm_moe( ) = self.prepare_static_weights_for_trtllm_fp4_moe(
layer.w13_weight, layer.w13_weight,
layer.w2_weight, layer.w2_weight,
layer.w13_weight_scale, layer.w13_weight_scale,
@ -1345,6 +1360,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
layer.w13_weight.size(-2) // 2, # intermediate_size layer.w13_weight.size(-2) // 2, # intermediate_size
layer.w13_weight.size(0), # num_experts layer.w13_weight.size(0), # num_experts
) )
logger.debug_once("Finished shuffling weights for TRT-LLM MOE")
layer.gemm1_weights_fp4_shuffled = Parameter( layer.gemm1_weights_fp4_shuffled = Parameter(
gemm1_weights_fp4_shuffled, requires_grad=False) gemm1_weights_fp4_shuffled, requires_grad=False)

View File

@ -1003,12 +1003,18 @@ def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> Optional[str]:
return None return None
return remapped_name return remapped_name
if any("mla_attn" in key for key in params_dict):
attn_str = "mla_attn.mla_attn"
logger.debug_once(f"Found mla_attn with k_scale and v_scale in "
f"the checkpoint, using {attn_str} as attn_str")
else:
attn_str = "attn"
# Define scale name mapping patterns in order of precedence # Define scale name mapping patterns in order of precedence
scale_mapping_patterns = [ scale_mapping_patterns = [
# ModelOpt format: .self_attn.{k,v}_proj.{k,v}_scale -> # ModelOpt format: .self_attn.{k,v}_proj.{k,v}_scale ->
# .self_attn.attn.{k,v}_scale # .self_attn.attn.{k,v}_scale
(r"\.self_attn\.([kv])_proj\.([kv])_scale$", (r"\.self_attn\.([kv])_proj\.([kv])_scale$",
r".self_attn.attn.\2_scale"), rf".self_attn.{attn_str}.\2_scale"),
# QKV proj format: .self_attn.qkv_proj.{k,v}_scale -> # QKV proj format: .self_attn.qkv_proj.{k,v}_scale ->
# .self_attn.attn.{k,v}_scale # .self_attn.attn.{k,v}_scale
(r"\.self_attn\.qkv_proj\.([kv])_scale$", r".self_attn.attn.\1_scale"), (r"\.self_attn\.qkv_proj\.([kv])_scale$", r".self_attn.attn.\1_scale"),