[Quantization] Add compressed-tensors NVFP4 MoE Support (#19990)

Signed-off-by: Dipika Sikka <dipikasikka1@gmail.com>
Signed-off-by: Dipika <dipikasikka1@gmail.com>
Author: Dipika Sikka
Date: 2025-06-30 00:05:40 +02:00
Committed by: GitHub
Parent: 7b1895e6ce
Commit: 6f2f53a82d
6 changed files with 295 additions and 22 deletions


@@ -17,7 +17,7 @@ from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tenso
CompressedTensorsW4A4Fp4, CompressedTensorsW4A16Fp4,
CompressedTensorsW4A16Sparse24, CompressedTensorsW8A8Fp8,
CompressedTensorsW8A8Int8, CompressedTensorsW8A16Fp8,
-    CompressedTensorsWNA16)
+    CompressedTensorsWNA16, cutlass_fp4_supported)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
sparse_cutlass_supported)
from vllm.platforms import current_platform
@@ -668,8 +668,8 @@ def test_compressed_tensors_nvfp4(vllm_runner, args):
assert isinstance(qkv_proj.quant_method,
CompressedTensorsLinearMethod)
if isinstance(qkv_proj.scheme, scheme) or isinstance(
-                qkv_proj.scheme, CompressedTensorsW4A16Fp4
-        ) and not CompressedTensorsW4A4Fp4.cutlass_fp4_supported():
+                qkv_proj.scheme,
+                CompressedTensorsW4A16Fp4) and not cutlass_fp4_supported():
assert True
else:
raise AssertionError("FP4 Scheme Mismatch")
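For context, the assertion above accepts either the expected FP4 scheme or the W4A16 fallback that is selected when CUTLASS FP4 kernels are unavailable. A minimal sketch of that gate, assuming the scheme classes are importable from the schemes package and using the helper imported above:

from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
    CompressedTensorsW4A16Fp4, CompressedTensorsW4A4Fp4)
from vllm.model_executor.layers.quantization.utils.nvfp4_emulation_utils import (
    cutlass_fp4_supported)

def expected_nvfp4_scheme() -> type:
    # W4A4 NVFP4 needs CUTLASS FP4 kernels; otherwise vLLM falls back to the
    # weight-only W4A16 FP4 scheme (or to emulation, if explicitly enabled).
    return (CompressedTensorsW4A4Fp4
            if cutlass_fp4_supported() else CompressedTensorsW4A16Fp4)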


@@ -1246,6 +1246,7 @@ class FusedMoE(torch.nn.Module):
param.materialize(final_shape, dtype=loaded_weight.dtype)
expert_data = param.data if full_load else param.data[expert_id]
# Case input scale: input_scale loading is only supported for fp8
if "input_scale" in weight_name:
# this is needed for compressed-tensors only
@@ -1273,6 +1274,7 @@ class FusedMoE(torch.nn.Module):
tp_rank=self.tp_rank)
return True if return_success else None
# TODO @dsikka: ModelOpt should follow the proper MoE loading pattern
if "ModelOpt" in quant_method_name:
if ('weight_scale_2' in weight_name
or 'input_scale' in weight_name):
@@ -1289,7 +1291,7 @@ class FusedMoE(torch.nn.Module):
tp_rank=self.tp_rank)
return True if return_success else None
-        # Case weight scales, zero_points and offset
+        # Case weight scales, zero_points and offset, weight/input global scales
if ("scale" in weight_name or "zero" in weight_name
or "offset" in weight_name):
# load the weight scales and zp based on the quantization scheme
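For reference, the NVFP4 global scales added by this commit reach the branch above because their parameter names contain "scale". A simplified sketch of the name-based dispatch, using an illustrative helper (not a vLLM API):

def _classify_moe_param(weight_name: str) -> str:
    # Simplified view of FusedMoE.weight_loader's dispatch by parameter name.
    if "input_scale" in weight_name:
        return "input_scale"   # per-tensor activation scales (fp8 paths)
    if ("scale" in weight_name or "zero" in weight_name
            or "offset" in weight_name):
        # weight scales, zero points, offsets, and the new NVFP4
        # weight/input *global* scales all land here
        return "scale_like"
    return "weight"

assert _classify_moe_param("w13_weight_global_scale") == "scale_like"
assert _classify_moe_param("w2_input_global_scale") == "scale_like"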


@@ -33,6 +33,8 @@ from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
find_matched_target, is_activation_quantization_format,
should_ignore_layer)
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.model_executor.layers.quantization.utils.nvfp4_emulation_utils import ( # noqa: E501
cutlass_fp4_supported)
from vllm.platforms import current_platform
logger = init_logger(__name__)
@@ -375,7 +377,7 @@ class CompressedTensorsConfig(QuantizationConfig):
if is_activation_quantization_format(self.quant_format):
if self._is_fp4a4_nvfp4(weight_quant, input_quant):
-                if CompressedTensorsW4A4Fp4.cutlass_fp4_supported(
+                if cutlass_fp4_supported(
) or envs.VLLM_USE_NVFP4_CT_EMULATIONS:
return CompressedTensorsW4A4Fp4()
else:

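The hunk above selects the fused W4A4 NVFP4 linear scheme whenever CUTLASS FP4 kernels are available or emulation is forced. A usage sketch for the emulation escape hatch (the model placeholder is illustrative):

# Force NVFP4 compressed-tensors emulation on hardware without CUTLASS FP4
# support, e.g. when launching the server from a shell:
#   VLLM_USE_NVFP4_CT_EMULATIONS=1 vllm serve <nvfp4-quantized-model>
import os
os.environ["VLLM_USE_NVFP4_CT_EMULATIONS"] = "1"  # set before the engine initializes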

@@ -21,8 +21,12 @@ from vllm.model_executor.layers.quantization.utils import replace_parameter
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
check_moe_marlin_supports_layer, marlin_make_workspace_new,
marlin_moe_permute_scales)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
prepare_moe_fp4_layer_for_marlin)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
prepare_moe_fp8_layer_for_marlin)
from vllm.model_executor.layers.quantization.utils.nvfp4_emulation_utils import ( # noqa: E501
cutlass_fp4_supported)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
all_close_1d, normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize)
from vllm.model_executor.utils import set_weight_attrs
@@ -46,12 +50,11 @@ class GPTQMarlinState(Enum):
__all__ = [
-    "CompressedTensorsMoEMethod",
-    "CompressedTensorsW8A8Fp8MoEMethod",
+    "CompressedTensorsMoEMethod", "CompressedTensorsW8A8Fp8MoEMethod",
    "CompressedTensorsW8A8Fp8MoECutlassMethod",
    "CompressedTensorsW8A8Int8MoEMethod",
-    "CompressedTensorsWNA16MarlinMoEMethod",
-    "CompressedTensorsWNA16MoEMethod",
+    "CompressedTensorsWNA16MarlinMoEMethod", "CompressedTensorsWNA16MoEMethod",
+    "CompressedTensorsW4A4MoeMethod"
]
@@ -84,6 +87,8 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
else:
logger.info_once("Using CompressedTensorsWNA16MarlinMoEMethod")
return CompressedTensorsWNA16MarlinMoEMethod(quant_config)
elif quant_config._is_fp4a4_nvfp4(weight_quant, input_quant):
return CompressedTensorsW4A4MoeMethod()
elif quant_config._is_fp8_w8a8_sm90(weight_quant, input_quant):
return CompressedTensorsW8A8Fp8MoECutlassMethod(quant_config)
elif quant_config._is_fp8_w8a8(weight_quant, input_quant):
@@ -95,6 +100,268 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
f"Unsupported FusedMoe scheme: {weight_quant}, {input_quant}")
class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
def __init__(self):
self.use_marlin = not cutlass_fp4_supported()
self.group_size = 16
def create_weights(self, layer: torch.nn.Module, num_experts: int,
hidden_size: int, intermediate_size_per_partition: int,
params_dtype: torch.dtype, **extra_weight_attrs):
layer.num_experts = num_experts
layer.params_dtype = params_dtype
w13_weight = torch.nn.Parameter(
torch.empty(
num_experts,
2 * intermediate_size_per_partition,
# 2 fp4 items are packed in the input dimension
hidden_size // 2,
requires_grad=False,
dtype=torch.uint8),
requires_grad=False)
layer.register_parameter("w13_weight_packed", w13_weight)
set_weight_attrs(w13_weight, extra_weight_attrs)
w2_weight = torch.nn.Parameter(
torch.empty(
num_experts,
hidden_size,
# 2 fp4 items are packed in the input dimension
intermediate_size_per_partition // 2,
dtype=torch.uint8),
requires_grad=False)
layer.register_parameter("w2_weight_packed", w2_weight)
set_weight_attrs(w2_weight, extra_weight_attrs)
# Weight Scales
w13_weight_scale = torch.nn.Parameter(
torch.empty(
num_experts,
2 * intermediate_size_per_partition,
# 2 fp4 items are packed in the input dimension
hidden_size // self.group_size,
dtype=torch.float8_e4m3fn),
requires_grad=False)
layer.register_parameter("w13_weight_scale", w13_weight_scale)
extra_weight_attrs.update(
{"quant_method": FusedMoeWeightScaleSupported.GROUP.value})
set_weight_attrs(w13_weight_scale, extra_weight_attrs)
w2_weight_scale = torch.nn.Parameter(
torch.empty(
num_experts,
hidden_size,
# 2 fp4 items are packed in the input dimension
intermediate_size_per_partition // self.group_size,
dtype=torch.float8_e4m3fn),
requires_grad=False)
layer.register_parameter("w2_weight_scale", w2_weight_scale)
extra_weight_attrs.update(
{"quant_method": FusedMoeWeightScaleSupported.GROUP.value})
set_weight_attrs(w2_weight_scale, extra_weight_attrs)
# Weight Global Scales
w13_weight_scale_2 = torch.nn.Parameter(torch.empty(
num_experts, 2, dtype=torch.float32),
requires_grad=False)
layer.register_parameter("w13_weight_global_scale", w13_weight_scale_2)
extra_weight_attrs.update(
{"quant_method": FusedMoeWeightScaleSupported.TENSOR.value})
set_weight_attrs(w13_weight_scale_2, extra_weight_attrs)
w2_weight_scale_2 = torch.nn.Parameter(torch.empty(
num_experts, dtype=torch.float32),
requires_grad=False)
layer.register_parameter("w2_weight_global_scale", w2_weight_scale_2)
extra_weight_attrs.update(
{"quant_method": FusedMoeWeightScaleSupported.TENSOR.value})
set_weight_attrs(w2_weight_scale_2, extra_weight_attrs)
# Input Global Scales
w13_input_scale = torch.nn.Parameter(torch.empty(num_experts,
2,
dtype=torch.float32),
requires_grad=False)
layer.register_parameter("w13_input_global_scale", w13_input_scale)
extra_weight_attrs.update(
{"quant_method": FusedMoeWeightScaleSupported.TENSOR.value})
set_weight_attrs(w13_input_scale, extra_weight_attrs)
w2_input_scale = torch.nn.Parameter(torch.empty(num_experts,
dtype=torch.float32),
requires_grad=False)
layer.register_parameter("w2_input_global_scale", w2_input_scale)
extra_weight_attrs.update(
{"quant_method": FusedMoeWeightScaleSupported.TENSOR.value})
set_weight_attrs(w2_input_scale, extra_weight_attrs)
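For reference, the parameters registered above have the following shapes for E experts, hidden size H, and per-partition intermediate size I (group_size = 16):

# w13_weight_packed:        (E, 2*I, H // 2)    uint8          two FP4 values per byte
# w13_weight_scale:         (E, 2*I, H // 16)   float8_e4m3fn  one scale per 16-element group
# w13_weight_global_scale:  (E, 2)              float32        one per gate/up projection
# w13_input_global_scale:   (E, 2)              float32
# w2_weight_packed:         (E, H, I // 2)      uint8
# w2_weight_scale:          (E, H, I // 16)     float8_e4m3fn
# w2_weight_global_scale:   (E,)                float32
# w2_input_global_scale:    (E,)                float32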
def swizzle_blockscale(self, scale: torch.tensor):
assert (scale.dtype == torch.float8_e4m3fn)
# Pad and blockwise interleave weight_scale
scale_ndim = scale.ndim
if scale.ndim == 2:
scale = scale.unsqueeze(0)
assert scale.ndim == 3
B, M, K = scale.shape
round_up_multiple = lambda x, m: (x + m - 1) // m * m
M_padded = round_up_multiple(M, 128)
K_padded = round_up_multiple(K, 4)
padded_scale = torch.zeros((B, M_padded, K_padded), dtype=scale.dtype)
padded_scale[:B, :M, :K] = scale
batches, rows, cols = padded_scale.shape
assert rows % 128 == 0
assert cols % 4 == 0
padded_scale = padded_scale.reshape(batches, rows // 128, 4, 32,
cols // 4, 4)
swizzled_scale = padded_scale.permute((0, 1, 4, 3, 2, 5))
swizzled_scale = swizzled_scale.contiguous().cuda()
return (swizzled_scale.reshape(M, K)
if scale_ndim == 2 else swizzled_scale.reshape(B, M, K))
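A shape walk-through of the swizzle above for a single 2-D scale tensor, with illustrative sizes that are already multiples of 128 and 4 (so no padding is added); each 128x4 tile of scales ends up as one contiguous 512-element chunk in the layout the CUTLASS FP4 kernels expect:

# scale      (256, 448)   float8_e4m3fn     M = 256, K = 448
# unsqueeze  (1, 256, 448)                  B = 1
# reshape    (1, 2, 4, 32, 112, 4)          (B, M // 128, 4, 32, K // 4, 4)
# permute    (1, 2, 112, 32, 4, 4)          dims (0, 1, 4, 3, 2, 5)
# reshape    (256, 448)                     same elements, block-interleaved order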
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
# From packed to weight
layer.w13_weight = torch.nn.Parameter(layer.w13_weight_packed.data,
requires_grad=False)
layer.w2_weight = torch.nn.Parameter(layer.w2_weight_packed.data,
requires_grad=False)
if not torch.allclose(layer.w13_weight_global_scale[:, 0],
layer.w13_weight_global_scale[:, 1]):
logger.warning_once(
"w1_weight_global_scale must match w3_weight_global_scale. "
"Accuracy may be affected.")
# Take inverse of global scale saved to disk
layer.w13_weight_scale_2 = torch.nn.Parameter(
1 / layer.w13_weight_global_scale[:, 0], requires_grad=False)
layer.w2_weight_scale_2 = torch.nn.Parameter(
1 / layer.w2_weight_global_scale.data, requires_grad=False)
if self.use_marlin:
prepare_moe_fp4_layer_for_marlin(layer)
return
# swizzle weight scales
layer.w13_blockscale_swizzled = torch.nn.Parameter(
self.swizzle_blockscale(layer.w13_weight_scale),
requires_grad=False)
layer.w2_blockscale_swizzled = torch.nn.Parameter(
self.swizzle_blockscale(layer.w2_weight_scale),
requires_grad=False)
# w13
w13_input_global_scale = layer.w13_input_global_scale.max(
dim=1).values.to(torch.float32)
layer.g1_alphas = torch.nn.Parameter(
((1 / w13_input_global_scale) * layer.w13_weight_scale_2),
requires_grad=False)
layer.w13_input_scale_quant = torch.nn.Parameter(
(w13_input_global_scale), requires_grad=False)
# w2
layer.g2_alphas = torch.nn.Parameter(
((1 / layer.w2_input_global_scale) * layer.w2_weight_scale_2).to(
torch.float32),
requires_grad=False)
layer.w2_input_scale_quant = torch.nn.Parameter(
(layer.w2_input_global_scale), requires_grad=False)
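Written out, the rescaling factors computed above reduce to products of inverse global scales; the per-group block scales are handed to the kernel separately as the swizzled tensors:

# w13_weight_scale_2    = 1 / w13_weight_global_scale[:, 0]
# g1_alphas             = w13_weight_scale_2 / w13_input_global_scale.max(dim=1).values
#                       = 1 / (w13_weight_global_scale[:, 0] * w13_input_global_scale.max(dim=1).values)
# w13_input_scale_quant = w13_input_global_scale.max(dim=1).values   # passed as a1_gscale in apply()
# (g2_alphas and w2_input_scale_quant are formed the same way for the down projection)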
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
expert_load_view: Optional[torch.Tensor] = None,
logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if enable_eplb:
raise NotImplementedError("EPLB not supported for "
"`CompressedTensorsW4A4MoeMethod` yet.")
topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias,
)
if self.use_marlin:
return torch.ops.vllm.fused_marlin_moe(
x,
layer.w13_weight,
layer.w2_weight,
layer.w13_weight_scale,
layer.w2_weight_scale,
router_logits,
topk_weights,
topk_ids,
global_scale1=layer.w13_weight_scale_2,
global_scale2=layer.w2_weight_scale_2,
quant_type_id=scalar_types.float4_e2m1f.id,
global_num_experts=global_num_experts,
expert_map=expert_map)
assert activation == "silu", "Only SiLU activation is supported."
assert not apply_router_weight_on_input, (
"Router weight on input is not "
"supported for CompressedTensorsW4A4MoeMethod.")
assert expert_map is None, ("Expert Parallelism / expert_map "
"is currently not supported for "
"CompressedTensorsW4A4MoeMethod.")
from vllm.model_executor.layers.fused_moe.cutlass_moe import (
cutlass_moe_fp4)
# Cutlass moe takes in activations in BF16/Half precision
# and fp4 quantized weights loaded from the checkpoint
return cutlass_moe_fp4(a=x,
w1_fp4=layer.w13_weight,
w1_blockscale=layer.w13_blockscale_swizzled,
w1_alphas=layer.g1_alphas,
w2_fp4=layer.w2_weight,
w2_blockscale=layer.w2_blockscale_swizzled,
w2_alphas=layer.g2_alphas,
topk_weights=topk_weights,
topk_ids=topk_ids,
m=x.shape[0],
n=layer.w2_weight.shape[2] * 2,
k=x.shape[1],
e=layer.w13_weight.shape[0],
a1_gscale=layer.w13_input_scale_quant,
a2_gscale=layer.w2_input_scale_quant,
device=x.device).to(x.dtype)
class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
def __init__(
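Taken together, this file now dispatches NVFP4 MoE layers to CompressedTensorsW4A4MoeMethod, which runs cutlass_moe_fp4 when cutlass_fp4_supported() is true and otherwise repacks the weights for fused_marlin_moe. A brief usage sketch (the checkpoint id below is hypothetical):

from vllm import LLM

# Hypothetical NVFP4 compressed-tensors MoE checkpoint; any checkpoint whose
# config matches _is_fp4a4_nvfp4 takes the CompressedTensorsW4A4MoeMethod path.
llm = LLM(model="nm-testing/Mixtral-8x7B-Instruct-NVFP4")  # hypothetical model id
print(llm.generate("Hello, my name is")[0].outputs[0].text)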


@@ -5,8 +5,7 @@ import torch
from torch.nn.parameter import Parameter
import vllm.envs as envs
-from vllm._custom_ops import (cutlass_scaled_fp4_mm,
-                              cutlass_scaled_mm_supports_fp4, scaled_fp4_quant)
+from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsScheme)
@@ -15,7 +14,6 @@ from vllm.model_executor.layers.quantization.utils.nvfp4_emulation_utils import
from vllm.model_executor.parameter import (GroupQuantScaleParameter,
ModelWeightParameter,
PerTensorScaleParameter)
-from vllm.platforms import current_platform
logger = init_logger(__name__)
@@ -33,15 +31,6 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
return 80
return 100
-    @classmethod
-    def cutlass_fp4_supported(cls) -> bool:
-        if not current_platform.is_cuda():
-            return False
-        capability_tuple = current_platform.get_device_capability()
-        capability = -1 if capability_tuple is None else capability_tuple.to_int(  # noqa: E501
-        )
-        return cutlass_scaled_mm_supports_fp4(capability)
def create_weights(self, layer: torch.nn.Module,
output_partition_sizes: list[int],
input_size_per_partition: int,


@@ -2,9 +2,14 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from vllm._custom_ops import cutlass_scaled_mm_supports_fp4
from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types
-__all__ = ["break_fp4_bytes", "dequantize_to_dtype", "ref_nvfp4_quant"]
+__all__ = [
+    "break_fp4_bytes", "dequantize_to_dtype", "ref_nvfp4_quant",
+    "cutlass_fp4_supported"
+]
FLOAT4_E2M1_MAX = scalar_types.float4_e2m1f.max()
@@ -12,6 +17,14 @@ kE2M1ToFloat = torch.tensor([0., 0.5, 1., 1.5, 2., 3., 4., 6.],
dtype=torch.float32)
def cutlass_fp4_supported() -> bool:
if not current_platform.is_cuda():
return False
capability_tuple = current_platform.get_device_capability()
capability = -1 if capability_tuple is None else capability_tuple.to_int()
return cutlass_scaled_mm_supports_fp4(capability)
def break_fp4_bytes(a, dtype):
assert a.dtype == torch.uint8
m, n = a.shape
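The helper added above centralizes the capability check shared by the linear and MoE FP4 paths. A usage sketch:

from vllm.model_executor.layers.quantization.utils.nvfp4_emulation_utils import (
    cutlass_fp4_supported)

if cutlass_fp4_supported():
    print("CUTLASS FP4 kernels available (e.g. Blackwell-class GPUs)")
else:
    print("NVFP4 will fall back to Marlin, W4A16, or emulation")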