Mirror of https://github.com/vllm-project/vllm.git (synced 2025-10-20 23:03:52 +08:00)

Commit 920db41128 (parent 9ea82ecd25), committed by yewentao256

[Quantization/NVFP4] Speed up TRTLLM NVFP4 MOE weight loading and fix K/V scale loading for MLA Attn (#25968)

Signed-off-by: Pavani Majety <pmajety@nvidia.com>
Signed-off-by: yewentao256 <zhyanwentao@126.com>
@@ -86,7 +86,7 @@ class BaseKVCacheMethod(QuantizeMethodBase):
             logger.warning_once(
                 "Checkpoint does not provide a q scaling factor. "
                 "Setting it to k_scale. This only matters for "
-                "the flash-attn backend.")
+                "FP8 Attention backends (flash-attn or flashinfer).")
             layer._q_scale.copy_(k_scale)
             layer._q_scale_float = k_scale

@@ -98,9 +98,9 @@ class BaseKVCacheMethod(QuantizeMethodBase):
         if (k_scale == 1.0 and v_scale == 1.0
                 and "e5m2" not in layer.kv_cache_dtype):
             logger.warning_once(
-                "Using KV cache scaling factor 1.0 for fp8_e4m3. This "
-                "may cause accuracy issues. Please make sure k/v_scale "
-                "scaling factors are available in the fp8 checkpoint.")
+                "Using KV cache scaling factor 1.0 for fp8_e4m3. "
+                "If this is unintended, verify that k/v_scale "
+                "scaling factors are properly set in the checkpoint.")

         if layer.q_scale > 0.0:
             q_scale = layer.q_scale
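For readers skimming the diff: the two reworded warnings above guard a simple fallback in `BaseKVCacheMethod` — if the checkpoint supplies no q scaling factor, `k_scale` is reused, and all-1.0 k/v scales with an fp8_e4m3 cache are flagged as suspicious. The sketch below is illustrative only; `resolve_attn_scales` is a hypothetical helper, not vLLM code.

```python
# Illustrative sketch only, not the vLLM implementation; names are hypothetical.
def resolve_attn_scales(k_scale: float, v_scale: float, q_scale: float | None,
                        kv_cache_dtype: str) -> tuple[float, float, float]:
    if q_scale is None or q_scale < 0.0:
        # Checkpoint does not provide a q scaling factor: fall back to k_scale.
        # This only matters for FP8 attention backends (flash-attn/flashinfer).
        q_scale = k_scale
    if k_scale == 1.0 and v_scale == 1.0 and "e5m2" not in kv_cache_dtype:
        print("warning: KV cache scaling factor 1.0 for fp8_e4m3; "
              "verify k/v_scale are set in the checkpoint")
    return q_scale, k_scale, v_scale


print(resolve_attn_scales(k_scale=1.0, v_scale=1.0, q_scale=None,
                          kv_cache_dtype="fp8_e4m3"))  # (1.0, 1.0, 1.0)
```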
@@ -1064,7 +1064,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
         self.allow_flashinfer = _nvfp4.allow_flashinfer
         self.use_marlin = _nvfp4.use_marlin
         self.flashinfer_moe_backend = None
+        self._cache_permute_indices: dict[torch.Size, torch.Tensor] = {}
         if self.allow_flashinfer:
             self.flashinfer_moe_backend = get_flashinfer_moe_backend()
             logger.info_once(
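The new `_cache_permute_indices` dict memoizes permute-index tensors by weight shape so they are computed once and reused across experts and layers during weight loading. A minimal sketch of that caching pattern follows; `compute_permute_indices` is a stand-in for the real layout computation, not a flashinfer API.

```python
# Minimal sketch of shape-keyed memoization; compute_permute_indices is a
# stand-in for the real (expensive) index computation, not a flashinfer API.
import torch

_cache_permute_indices: dict[torch.Size, torch.Tensor] = {}


def compute_permute_indices(shape: torch.Size) -> torch.Tensor:
    return torch.randperm(shape[0])  # placeholder for the real layout logic


def maybe_get_cached_permute_indices(w: torch.Tensor) -> torch.Tensor:
    # Weights with the same shape share one cached permutation tensor.
    if w.shape not in _cache_permute_indices:
        _cache_permute_indices[w.shape] = compute_permute_indices(w.shape)
    return _cache_permute_indices[w.shape]


a = maybe_get_cached_permute_indices(torch.zeros(8, 4))
b = maybe_get_cached_permute_indices(torch.ones(8, 4))  # cache hit (same shape)
assert a is b
```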
@@ -1197,19 +1197,23 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
                                                  weight_loader=weight_loader)
         layer.register_parameter("w2_input_scale", w2_input_scale)

-    def prepare_static_weight_layouts_for_trtllm_moe(
+    def prepare_static_weights_for_trtllm_fp4_moe(
         self,
-        gemm1_weights: torch.Tensor,
-        gemm2_weights: torch.Tensor,
-        gemm1_scales_linear_fp4_bytes: torch.Tensor,
-        gemm2_scales_linear_fp4_bytes: torch.Tensor,
-        hidden_size: int,
-        intermediate_size: int,
-        num_experts: int,
-    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+        # args_dequant,
+        # args,
+        gemm1_weights,
+        gemm2_weights,
+        gemm1_scales_linear_fp4_bytes,
+        gemm2_scales_linear_fp4_bytes,
+        hidden_size,
+        intermediate_size,
+        num_experts,
+    ):
+        from flashinfer import nvfp4_block_scale_interleave
+        from flashinfer.fused_moe.core import (
+            _maybe_get_cached_w2_permute_indices,
+            _maybe_get_cached_w3_w1_permute_indices)
         """Prepare quantized weights for kernel (done offline with weights)."""
-        from flashinfer import (reorder_rows_for_gated_act_gemm,
-                                shuffle_matrix_a, shuffle_matrix_sf_a)
         epilogue_tile_m = 128  # FIXME: this depends on the kernel internals

         # Convert quantized weights to proper formats
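Note that the flashinfer helpers are now imported inside the method rather than at module scope, so the module stays importable on systems without flashinfer and only the TRT-LLM path requires it. A small sketch of that lazy-import pattern, with a hypothetical wrapper name, is shown below.

```python
# Sketch of the lazy-import pattern (hypothetical wrapper, not vLLM code).
def _load_flashinfer_interleave():
    try:
        from flashinfer import nvfp4_block_scale_interleave
    except ImportError as exc:  # flashinfer not installed
        raise RuntimeError(
            "flashinfer is required for the TRT-LLM NVFP4 MoE path") from exc
    return nvfp4_block_scale_interleave
```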
@@ -1227,48 +1231,54 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
                                                   intermediate_size //
                                                   16)  # fp8 scaling factors

-        # Reorder rows of W1 and scales for fused gated activation
-        gemm1_weights_fp4_interleaved = []
-        gemm1_scales_fp4_interleaved = []
-        for i in range(num_experts):
-            gemm1_weights_fp4_interleaved.append(
-                reorder_rows_for_gated_act_gemm(gemm1_weights_fp4[i].clone()))
-            gemm1_scales_fp4_interleaved.append(
-                reorder_rows_for_gated_act_gemm(
-                    gemm1_scales_linear_fp4[i].clone()))
-
-        # Stack weights and scales for all experts
-        gemm1_weights_fp4_interleaved = torch.stack(
-            gemm1_weights_fp4_interleaved).reshape(num_experts,
-                                                   2 * intermediate_size,
-                                                   hidden_size // 2)
-        gemm1_scales_fp4_interleaved = torch.stack(
-            gemm1_scales_fp4_interleaved).reshape(num_experts,
-                                                  2 * intermediate_size,
-                                                  hidden_size // 16)
-
-        # Shuffle weights and scaling factors for transposed mma output
         gemm1_weights_fp4_shuffled = []
         gemm1_scales_fp4_shuffled = []
         gemm2_weights_fp4_shuffled = []
         gemm2_scales_fp4_shuffled = []
         for i in range(num_experts):
-            gemm1_weights_fp4_shuffled.append(
-                shuffle_matrix_a(
-                    gemm1_weights_fp4_interleaved[i].view(torch.uint8),
-                    epilogue_tile_m))
-            gemm1_scales_fp4_shuffled.append(
-                shuffle_matrix_sf_a(
-                    gemm1_scales_fp4_interleaved[i].view(torch.uint8),
-                    epilogue_tile_m))
+            # Calculate the permute indices for the following:
+            # 1. Reorder rows of W1 and scales for fused gated activation
+            # 2. Shuffle weights and scaling factors for transposed mma output
+            # for both w3_w1 and w2 weights and scale factors
+            permute_indices = _maybe_get_cached_w3_w1_permute_indices(
+                self._cache_permute_indices,
+                gemm1_weights_fp4[i].view(torch.uint8),
+                epilogue_tile_m,
+            )
+            gemm1_weights_fp4_shuffled.append(gemm1_weights_fp4[i].view(
+                torch.uint8)[permute_indices.to(
+                    gemm1_weights_fp4.device)].contiguous())

-            gemm2_weights_fp4_shuffled.append(
-                shuffle_matrix_a(gemm2_weights_fp4[i].view(torch.uint8),
-                                 epilogue_tile_m))
+            permute_sf_indices = _maybe_get_cached_w3_w1_permute_indices(
+                self._cache_permute_indices,
+                gemm1_scales_linear_fp4[i].view(torch.uint8),
+                epilogue_tile_m,
+                num_elts_per_sf=16,
+            )
+            gemm1_scales_fp4_shuffled.append(
+                nvfp4_block_scale_interleave(gemm1_scales_linear_fp4[i].view(
+                    torch.uint8)[permute_sf_indices.to(
+                        gemm1_scales_linear_fp4.device)].contiguous()))
+
+            permute_indices = _maybe_get_cached_w2_permute_indices(
+                self._cache_permute_indices,
+                gemm2_weights_fp4[i].view(torch.uint8),
+                epilogue_tile_m,
+            )
+            gemm2_weights_fp4_shuffled.append(gemm2_weights_fp4[i].view(
+                torch.uint8)[permute_indices.to(
+                    gemm2_weights_fp4.device)].contiguous())
+
+            permute_sf_indices = _maybe_get_cached_w2_permute_indices(
+                self._cache_permute_indices,
+                gemm2_scales_linear_fp4[i].view(torch.uint8),
+                epilogue_tile_m,
+                num_elts_per_sf=16,
+            )
             gemm2_scales_fp4_shuffled.append(
-                shuffle_matrix_sf_a(
-                    gemm2_scales_linear_fp4[i].view(torch.uint8),
-                    epilogue_tile_m))
+                nvfp4_block_scale_interleave(gemm2_scales_linear_fp4[i].view(
+                    torch.uint8)[permute_sf_indices.to(
+                        gemm2_scales_linear_fp4.device)].contiguous()))

         # Stack weights for all experts
         gemm1_weights_fp4_shuffled = torch.stack(gemm1_weights_fp4_shuffled)
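Conceptually, the old path materialized reordered and interleaved copies of every expert's weights before shuffling them, while the new path computes (and caches) a permutation index tensor once and applies it to each expert with a single gather. A toy illustration with made-up shapes, not the real NVFP4 layouts:

```python
# Toy illustration with made-up shapes; the real permutation comes from
# flashinfer's cached permute-index helpers, not torch.randperm.
import torch

num_experts, rows, cols = 4, 16, 8
weights = torch.randint(0, 255, (num_experts, rows, cols), dtype=torch.uint8)

perm = torch.randperm(rows)  # computed once, then cached and reused
shuffled = torch.stack(
    [weights[i][perm].contiguous() for i in range(num_experts)])
assert shuffled.shape == weights.shape
```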
@@ -1283,8 +1293,12 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
             torch.stack(gemm2_scales_fp4_shuffled).view(
                 torch.float8_e4m3fn).reshape(num_experts, hidden_size,
                                              intermediate_size // 16))
-        return (gemm1_weights_fp4_shuffled, gemm1_scales_fp4_shuffled,
-                gemm2_weights_fp4_shuffled, gemm2_scales_fp4_shuffled)
+        return (
+            gemm1_weights_fp4_shuffled,
+            gemm1_scales_fp4_shuffled,
+            gemm2_weights_fp4_shuffled,
+            gemm2_scales_fp4_shuffled,
+        )

     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         # GEMM 1 processing
@@ -1334,9 +1348,10 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
         if self.allow_flashinfer and \
            self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
             # Prepare static weights for TRT-LLM kernel
+            # alternate: prepare_static_weight_layouts_for_trtllm_moe
             (gemm1_weights_fp4_shuffled, gemm1_scales_fp4_shuffled,
              gemm2_weights_fp4_shuffled, gemm2_scales_fp4_shuffled
-             ) = self.prepare_static_weight_layouts_for_trtllm_moe(
+             ) = self.prepare_static_weights_for_trtllm_fp4_moe(
                  layer.w13_weight,
                  layer.w2_weight,
                  layer.w13_weight_scale,
@@ -1345,6 +1360,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
                  layer.w13_weight.size(-2) // 2,  # intermediate_size
                  layer.w13_weight.size(0),  # num_experts
              )
+            logger.debug_once("Finished shuffling weights for TRT-LLM MOE")

             layer.gemm1_weights_fp4_shuffled = Parameter(
                 gemm1_weights_fp4_shuffled, requires_grad=False)

@@ -1003,12 +1003,18 @@ def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> Optional[str]:
            return None
        return remapped_name

+    if any("mla_attn" in key for key in params_dict):
+        attn_str = "mla_attn.mla_attn"
+        logger.debug_once(f"Found mla_attn with k_scale and v_scale in "
+                          f"the checkpoint, using {attn_str} as attn_str")
+    else:
+        attn_str = "attn"
     # Define scale name mapping patterns in order of precedence
     scale_mapping_patterns = [
         # ModelOpt format: .self_attn.{k,v}_proj.{k,v}_scale ->
         # .self_attn.attn.{k,v}_scale
         (r"\.self_attn\.([kv])_proj\.([kv])_scale$",
-         r".self_attn.attn.\2_scale"),
+         rf".self_attn.{attn_str}.\2_scale"),
         # QKV proj format: .self_attn.qkv_proj.{k,v}_scale ->
         # .self_attn.attn.{k,v}_scale
         (r"\.self_attn\.qkv_proj\.([kv])_scale$", r".self_attn.attn.\1_scale"),
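The effect of the MLA-aware remapping can be checked with plain `re`: when any parameter key contains "mla_attn", the ModelOpt pattern now rewrites scale names onto the nested `mla_attn.mla_attn` module instead of `attn`. A standalone example:

```python
# Standalone check of the updated ModelOpt remapping pattern.
import re

attn_str = "mla_attn.mla_attn"  # selected when params_dict contains mla_attn keys
pattern = r"\.self_attn\.([kv])_proj\.([kv])_scale$"
name = "model.layers.0.self_attn.k_proj.k_scale"

print(re.sub(pattern, rf".self_attn.{attn_str}.\2_scale", name))
# -> model.layers.0.self_attn.mla_attn.mla_attn.k_scale
```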
|