mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 14:53:52 +08:00
[ModelOpt] Remove NVFP4 MoE K%16==0 constraint (#26891)
Signed-off-by: XiaobingSuper <xiaobingzhangupc@gmail.com>
This commit is contained in:
@@ -1542,23 +1542,11 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
|
||||
del layer.w2_input_scale_quant
|
||||
else:
|
||||
# Non-TRT-LLM processing (Cutlass or non-flashinfer)
|
||||
assert layer.w13_weight_scale.shape[2] % 16 == 0, (
|
||||
"Expected weight_scale.dim(1) to be divisible by 16"
|
||||
)
|
||||
assert layer.w13_weight_scale.dtype == torch.float8_e4m3fn, (
|
||||
"Weight Blockscale must be represented as FP8-E4M3"
|
||||
)
|
||||
w13_blockscale_swizzled = swizzle_blockscale(layer.w13_weight_scale)
|
||||
layer.w13_weight_scale = Parameter(
|
||||
w13_blockscale_swizzled, requires_grad=False
|
||||
)
|
||||
|
||||
assert layer.w2_weight_scale.shape[2] % 16 == 0, (
|
||||
"Expected weight_scale.dim(1) to be divisible by 16"
|
||||
)
|
||||
assert layer.w2_weight_scale.dtype == torch.float8_e4m3fn, (
|
||||
"Weight Blockscale must be represented as FP8-E4M3"
|
||||
)
|
||||
w2_blockscale_swizzled = swizzle_blockscale(layer.w2_weight_scale)
|
||||
layer.w2_weight_scale = Parameter(
|
||||
w2_blockscale_swizzled, requires_grad=False
|
||||
|
Reference in New Issue
Block a user