mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 14:53:52 +08:00
[Quantization] Add compressed-tensors NVFP4 MoE Support (#19990)
Signed-off-by: Dipika Sikka <dipikasikka1@gmail.com> Signed-off-by: Dipika <dipikasikka1@gmail.com>
This commit is contained in:
@ -17,7 +17,7 @@ from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tenso
|
||||
CompressedTensorsW4A4Fp4, CompressedTensorsW4A16Fp4,
|
||||
CompressedTensorsW4A16Sparse24, CompressedTensorsW8A8Fp8,
|
||||
CompressedTensorsW8A8Int8, CompressedTensorsW8A16Fp8,
|
||||
CompressedTensorsWNA16)
|
||||
CompressedTensorsWNA16, cutlass_fp4_supported)
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
sparse_cutlass_supported)
|
||||
from vllm.platforms import current_platform
|
||||
@ -668,8 +668,8 @@ def test_compressed_tensors_nvfp4(vllm_runner, args):
|
||||
assert isinstance(qkv_proj.quant_method,
|
||||
CompressedTensorsLinearMethod)
|
||||
if isinstance(qkv_proj.scheme, scheme) or isinstance(
|
||||
qkv_proj.scheme, CompressedTensorsW4A16Fp4
|
||||
) and not CompressedTensorsW4A4Fp4.cutlass_fp4_supported():
|
||||
qkv_proj.scheme,
|
||||
CompressedTensorsW4A16Fp4) and not cutlass_fp4_supported():
|
||||
assert True
|
||||
else:
|
||||
raise AssertionError("FP4 Scheme Mismatch")
|
||||
|
@ -1246,6 +1246,7 @@ class FusedMoE(torch.nn.Module):
|
||||
param.materialize(final_shape, dtype=loaded_weight.dtype)
|
||||
|
||||
expert_data = param.data if full_load else param.data[expert_id]
|
||||
|
||||
# Case input scale: input_scale loading is only supported for fp8
|
||||
if "input_scale" in weight_name:
|
||||
# this is needed for compressed-tensors only
|
||||
@ -1273,6 +1274,7 @@ class FusedMoE(torch.nn.Module):
|
||||
tp_rank=self.tp_rank)
|
||||
return True if return_success else None
|
||||
|
||||
# TODO @dsikka: ModelOpt should follow the proper MoE loading pattern
|
||||
if "ModelOpt" in quant_method_name:
|
||||
if ('weight_scale_2' in weight_name
|
||||
or 'input_scale' in weight_name):
|
||||
@ -1289,7 +1291,7 @@ class FusedMoE(torch.nn.Module):
|
||||
tp_rank=self.tp_rank)
|
||||
return True if return_success else None
|
||||
|
||||
# Case weight scales, zero_points and offset
|
||||
# Case weight scales, zero_points and offset, weight/input global scales
|
||||
if ("scale" in weight_name or "zero" in weight_name
|
||||
or "offset" in weight_name):
|
||||
# load the weight scales and zp based on the quantization scheme
|
||||
|
@ -33,6 +33,8 @@ from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
|
||||
find_matched_target, is_activation_quantization_format,
|
||||
should_ignore_layer)
|
||||
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
|
||||
from vllm.model_executor.layers.quantization.utils.nvfp4_emulation_utils import ( # noqa: E501
|
||||
cutlass_fp4_supported)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@ -375,7 +377,7 @@ class CompressedTensorsConfig(QuantizationConfig):
|
||||
|
||||
if is_activation_quantization_format(self.quant_format):
|
||||
if self._is_fp4a4_nvfp4(weight_quant, input_quant):
|
||||
if CompressedTensorsW4A4Fp4.cutlass_fp4_supported(
|
||||
if cutlass_fp4_supported(
|
||||
) or envs.VLLM_USE_NVFP4_CT_EMULATIONS:
|
||||
return CompressedTensorsW4A4Fp4()
|
||||
else:
|
||||
|
@ -21,8 +21,12 @@ from vllm.model_executor.layers.quantization.utils import replace_parameter
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
|
||||
check_moe_marlin_supports_layer, marlin_make_workspace_new,
|
||||
marlin_moe_permute_scales)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
|
||||
prepare_moe_fp4_layer_for_marlin)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
|
||||
prepare_moe_fp8_layer_for_marlin)
|
||||
from vllm.model_executor.layers.quantization.utils.nvfp4_emulation_utils import ( # noqa: E501
|
||||
cutlass_fp4_supported)
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
all_close_1d, normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize)
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
@ -46,12 +50,11 @@ class GPTQMarlinState(Enum):
|
||||
|
||||
|
||||
__all__ = [
|
||||
"CompressedTensorsMoEMethod",
|
||||
"CompressedTensorsW8A8Fp8MoEMethod",
|
||||
"CompressedTensorsMoEMethod", "CompressedTensorsW8A8Fp8MoEMethod",
|
||||
"CompressedTensorsW8A8Fp8MoECutlassMethod",
|
||||
"CompressedTensorsW8A8Int8MoEMethod",
|
||||
"CompressedTensorsWNA16MarlinMoEMethod",
|
||||
"CompressedTensorsWNA16MoEMethod",
|
||||
"CompressedTensorsWNA16MarlinMoEMethod", "CompressedTensorsWNA16MoEMethod",
|
||||
"CompressedTensorsW4A4MoeMethod"
|
||||
]
|
||||
|
||||
|
||||
@ -84,6 +87,8 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
|
||||
else:
|
||||
logger.info_once("Using CompressedTensorsWNA16MarlinMoEMethod")
|
||||
return CompressedTensorsWNA16MarlinMoEMethod(quant_config)
|
||||
elif quant_config._is_fp4a4_nvfp4(weight_quant, input_quant):
|
||||
return CompressedTensorsW4A4MoeMethod()
|
||||
elif quant_config._is_fp8_w8a8_sm90(weight_quant, input_quant):
|
||||
return CompressedTensorsW8A8Fp8MoECutlassMethod(quant_config)
|
||||
elif quant_config._is_fp8_w8a8(weight_quant, input_quant):
|
||||
@ -95,6 +100,268 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
|
||||
f"Unsupported FusedMoe scheme: {weight_quant}, {input_quant}")
|
||||
|
||||
|
||||
class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
|
||||
|
||||
def __init__(self):
|
||||
self.use_marlin = not cutlass_fp4_supported()
|
||||
self.group_size = 16
|
||||
|
||||
def create_weights(self, layer: torch.nn.Module, num_experts: int,
|
||||
hidden_size: int, intermediate_size_per_partition: int,
|
||||
params_dtype: torch.dtype, **extra_weight_attrs):
|
||||
|
||||
layer.num_experts = num_experts
|
||||
layer.params_dtype = params_dtype
|
||||
|
||||
w13_weight = torch.nn.Parameter(
|
||||
torch.empty(
|
||||
num_experts,
|
||||
2 * intermediate_size_per_partition,
|
||||
# 2 fp4 items are packed in the input dimension
|
||||
hidden_size // 2,
|
||||
requires_grad=False,
|
||||
dtype=torch.uint8),
|
||||
requires_grad=False)
|
||||
layer.register_parameter("w13_weight_packed", w13_weight)
|
||||
set_weight_attrs(w13_weight, extra_weight_attrs)
|
||||
|
||||
w2_weight = torch.nn.Parameter(
|
||||
torch.empty(
|
||||
num_experts,
|
||||
hidden_size,
|
||||
# 2 fp4 items are packed in the input dimension
|
||||
intermediate_size_per_partition // 2,
|
||||
dtype=torch.uint8),
|
||||
requires_grad=False)
|
||||
layer.register_parameter("w2_weight_packed", w2_weight)
|
||||
set_weight_attrs(w2_weight, extra_weight_attrs)
|
||||
|
||||
# Weight Scales
|
||||
w13_weight_scale = torch.nn.Parameter(
|
||||
torch.empty(
|
||||
num_experts,
|
||||
2 * intermediate_size_per_partition,
|
||||
# 2 fp4 items are packed in the input dimension
|
||||
hidden_size // self.group_size,
|
||||
dtype=torch.float8_e4m3fn),
|
||||
requires_grad=False)
|
||||
layer.register_parameter("w13_weight_scale", w13_weight_scale)
|
||||
extra_weight_attrs.update(
|
||||
{"quant_method": FusedMoeWeightScaleSupported.GROUP.value})
|
||||
set_weight_attrs(w13_weight_scale, extra_weight_attrs)
|
||||
|
||||
w2_weight_scale = torch.nn.Parameter(
|
||||
torch.empty(
|
||||
num_experts,
|
||||
hidden_size,
|
||||
# 2 fp4 items are packed in the input dimension
|
||||
intermediate_size_per_partition // self.group_size,
|
||||
dtype=torch.float8_e4m3fn),
|
||||
requires_grad=False)
|
||||
layer.register_parameter("w2_weight_scale", w2_weight_scale)
|
||||
extra_weight_attrs.update(
|
||||
{"quant_method": FusedMoeWeightScaleSupported.GROUP.value})
|
||||
set_weight_attrs(w2_weight_scale, extra_weight_attrs)
|
||||
|
||||
# Weight Global Scales
|
||||
w13_weight_scale_2 = torch.nn.Parameter(torch.empty(
|
||||
num_experts, 2, dtype=torch.float32),
|
||||
requires_grad=False)
|
||||
layer.register_parameter("w13_weight_global_scale", w13_weight_scale_2)
|
||||
extra_weight_attrs.update(
|
||||
{"quant_method": FusedMoeWeightScaleSupported.TENSOR.value})
|
||||
set_weight_attrs(w13_weight_scale_2, extra_weight_attrs)
|
||||
|
||||
w2_weight_scale_2 = torch.nn.Parameter(torch.empty(
|
||||
num_experts, dtype=torch.float32),
|
||||
requires_grad=False)
|
||||
layer.register_parameter("w2_weight_global_scale", w2_weight_scale_2)
|
||||
extra_weight_attrs.update(
|
||||
{"quant_method": FusedMoeWeightScaleSupported.TENSOR.value})
|
||||
set_weight_attrs(w2_weight_scale_2, extra_weight_attrs)
|
||||
|
||||
# Input Global Scales
|
||||
w13_input_scale = torch.nn.Parameter(torch.empty(num_experts,
|
||||
2,
|
||||
dtype=torch.float32),
|
||||
requires_grad=False)
|
||||
layer.register_parameter("w13_input_global_scale", w13_input_scale)
|
||||
extra_weight_attrs.update(
|
||||
{"quant_method": FusedMoeWeightScaleSupported.TENSOR.value})
|
||||
set_weight_attrs(w13_input_scale, extra_weight_attrs)
|
||||
|
||||
w2_input_scale = torch.nn.Parameter(torch.empty(num_experts,
|
||||
dtype=torch.float32),
|
||||
requires_grad=False)
|
||||
layer.register_parameter("w2_input_global_scale", w2_input_scale)
|
||||
extra_weight_attrs.update(
|
||||
{"quant_method": FusedMoeWeightScaleSupported.TENSOR.value})
|
||||
set_weight_attrs(w2_input_scale, extra_weight_attrs)
|
||||
|
||||
def swizzle_blockscale(self, scale: torch.tensor):
|
||||
assert (scale.dtype == torch.float8_e4m3fn)
|
||||
# Pad and blockwise interleave weight_scale
|
||||
scale_ndim = scale.ndim
|
||||
if scale.ndim == 2:
|
||||
scale = scale.unsqueeze(0)
|
||||
assert scale.ndim == 3
|
||||
B, M, K = scale.shape
|
||||
round_up_multiple = lambda x, m: (x + m - 1) // m * m
|
||||
M_padded = round_up_multiple(M, 128)
|
||||
K_padded = round_up_multiple(K, 4)
|
||||
padded_scale = torch.zeros((B, M_padded, K_padded), dtype=scale.dtype)
|
||||
padded_scale[:B, :M, :K] = scale
|
||||
batches, rows, cols = padded_scale.shape
|
||||
assert rows % 128 == 0
|
||||
assert cols % 4 == 0
|
||||
padded_scale = padded_scale.reshape(batches, rows // 128, 4, 32,
|
||||
cols // 4, 4)
|
||||
swizzled_scale = padded_scale.permute((0, 1, 4, 3, 2, 5))
|
||||
swizzled_scale = swizzled_scale.contiguous().cuda()
|
||||
return (swizzled_scale.reshape(M, K)
|
||||
if scale_ndim == 2 else swizzled_scale.reshape(B, M, K))
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
|
||||
# From packed to weight
|
||||
layer.w13_weight = torch.nn.Parameter(layer.w13_weight_packed.data,
|
||||
requires_grad=False)
|
||||
|
||||
layer.w2_weight = torch.nn.Parameter(layer.w2_weight_packed.data,
|
||||
requires_grad=False)
|
||||
|
||||
if not torch.allclose(layer.w13_weight_global_scale[:, 0],
|
||||
layer.w13_weight_global_scale[:, 1]):
|
||||
logger.warning_once(
|
||||
"w1_weight_global_scale must match w3_weight_global_scale. "
|
||||
"Accuracy may be affected.")
|
||||
|
||||
# Take inverse of global scale saved to disk
|
||||
layer.w13_weight_scale_2 = torch.nn.Parameter(
|
||||
1 / layer.w13_weight_global_scale[:, 0], requires_grad=False)
|
||||
|
||||
layer.w2_weight_scale_2 = torch.nn.Parameter(
|
||||
1 / layer.w2_weight_global_scale.data, requires_grad=False)
|
||||
|
||||
if self.use_marlin:
|
||||
prepare_moe_fp4_layer_for_marlin(layer)
|
||||
return
|
||||
|
||||
# swizzle weight scales
|
||||
layer.w13_blockscale_swizzled = torch.nn.Parameter(
|
||||
self.swizzle_blockscale(layer.w13_weight_scale),
|
||||
requires_grad=False)
|
||||
|
||||
layer.w2_blockscale_swizzled = torch.nn.Parameter(
|
||||
self.swizzle_blockscale(layer.w2_weight_scale),
|
||||
requires_grad=False)
|
||||
|
||||
# w13
|
||||
w13_input_global_scale = layer.w13_input_global_scale.max(
|
||||
dim=1).values.to(torch.float32)
|
||||
|
||||
layer.g1_alphas = torch.nn.Parameter(
|
||||
((1 / w13_input_global_scale) * layer.w13_weight_scale_2),
|
||||
requires_grad=False)
|
||||
|
||||
layer.w13_input_scale_quant = torch.nn.Parameter(
|
||||
(w13_input_global_scale), requires_grad=False)
|
||||
|
||||
# w2
|
||||
layer.g2_alphas = torch.nn.Parameter(
|
||||
((1 / layer.w2_input_global_scale) * layer.w2_weight_scale_2).to(
|
||||
torch.float32),
|
||||
requires_grad=False)
|
||||
|
||||
layer.w2_input_scale_quant = torch.nn.Parameter(
|
||||
(layer.w2_input_global_scale), requires_grad=False)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
top_k: int,
|
||||
renormalize: bool,
|
||||
use_grouped_topk: bool = False,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
scoring_func: str = "softmax",
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
activation: str = "silu",
|
||||
enable_eplb: bool = False,
|
||||
expert_load_view: Optional[torch.Tensor] = None,
|
||||
logical_to_physical_map: Optional[torch.Tensor] = None,
|
||||
logical_replica_count: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
if enable_eplb:
|
||||
raise NotImplementedError("EPLB not supported for "
|
||||
"`CompressedTensorsW4A4MoeMethod` yet.")
|
||||
|
||||
topk_weights, topk_ids = FusedMoE.select_experts(
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
use_grouped_topk=use_grouped_topk,
|
||||
top_k=top_k,
|
||||
renormalize=renormalize,
|
||||
topk_group=topk_group,
|
||||
num_expert_group=num_expert_group,
|
||||
custom_routing_function=custom_routing_function,
|
||||
scoring_func=scoring_func,
|
||||
e_score_correction_bias=e_score_correction_bias,
|
||||
)
|
||||
|
||||
if self.use_marlin:
|
||||
return torch.ops.vllm.fused_marlin_moe(
|
||||
x,
|
||||
layer.w13_weight,
|
||||
layer.w2_weight,
|
||||
layer.w13_weight_scale,
|
||||
layer.w2_weight_scale,
|
||||
router_logits,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
global_scale1=layer.w13_weight_scale_2,
|
||||
global_scale2=layer.w2_weight_scale_2,
|
||||
quant_type_id=scalar_types.float4_e2m1f.id,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map)
|
||||
|
||||
assert activation == "silu", "Only SiLU activation is supported."
|
||||
assert not apply_router_weight_on_input, (
|
||||
"Router weight on input is not "
|
||||
"supported for CompressedTensorsW4A4MoeMethod.")
|
||||
assert expert_map is None, ("Expert Parallelism / expert_map "
|
||||
"is currently not supported for "
|
||||
"CompressedTensorsW4A4MoeMethod.")
|
||||
|
||||
from vllm.model_executor.layers.fused_moe.cutlass_moe import (
|
||||
cutlass_moe_fp4)
|
||||
|
||||
# Cutlass moe takes in activations in BF16/Half precision
|
||||
# and fp4 quantized weights loaded from the checkpoint
|
||||
return cutlass_moe_fp4(a=x,
|
||||
w1_fp4=layer.w13_weight,
|
||||
w1_blockscale=layer.w13_blockscale_swizzled,
|
||||
w1_alphas=layer.g1_alphas,
|
||||
w2_fp4=layer.w2_weight,
|
||||
w2_blockscale=layer.w2_blockscale_swizzled,
|
||||
w2_alphas=layer.g2_alphas,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
m=x.shape[0],
|
||||
n=layer.w2_weight.shape[2] * 2,
|
||||
k=x.shape[1],
|
||||
e=layer.w13_weight.shape[0],
|
||||
a1_gscale=layer.w13_input_scale_quant,
|
||||
a2_gscale=layer.w2_input_scale_quant,
|
||||
device=x.device).to(x.dtype)
|
||||
|
||||
|
||||
class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
||||
|
||||
def __init__(
|
||||
|
@ -5,8 +5,7 @@ import torch
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm._custom_ops import (cutlass_scaled_fp4_mm,
|
||||
cutlass_scaled_mm_supports_fp4, scaled_fp4_quant)
|
||||
from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
|
||||
CompressedTensorsScheme)
|
||||
@ -15,7 +14,6 @@ from vllm.model_executor.layers.quantization.utils.nvfp4_emulation_utils import
|
||||
from vllm.model_executor.parameter import (GroupQuantScaleParameter,
|
||||
ModelWeightParameter,
|
||||
PerTensorScaleParameter)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -33,15 +31,6 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
|
||||
return 80
|
||||
return 100
|
||||
|
||||
@classmethod
|
||||
def cutlass_fp4_supported(cls) -> bool:
|
||||
if not current_platform.is_cuda():
|
||||
return False
|
||||
capability_tuple = current_platform.get_device_capability()
|
||||
capability = -1 if capability_tuple is None else capability_tuple.to_int( # noqa: E501
|
||||
)
|
||||
return cutlass_scaled_mm_supports_fp4(capability)
|
||||
|
||||
def create_weights(self, layer: torch.nn.Module,
|
||||
output_partition_sizes: list[int],
|
||||
input_size_per_partition: int,
|
||||
|
@ -2,9 +2,14 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import torch
|
||||
|
||||
from vllm._custom_ops import cutlass_scaled_mm_supports_fp4
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.scalar_type import scalar_types
|
||||
|
||||
__all__ = ["break_fp4_bytes", "dequantize_to_dtype", "ref_nvfp4_quant"]
|
||||
__all__ = [
|
||||
"break_fp4_bytes", "dequantize_to_dtype", "ref_nvfp4_quant",
|
||||
"cutlass_fp4_supported"
|
||||
]
|
||||
|
||||
FLOAT4_E2M1_MAX = scalar_types.float4_e2m1f.max()
|
||||
|
||||
@ -12,6 +17,14 @@ kE2M1ToFloat = torch.tensor([0., 0.5, 1., 1.5, 2., 3., 4., 6.],
|
||||
dtype=torch.float32)
|
||||
|
||||
|
||||
def cutlass_fp4_supported() -> bool:
|
||||
if not current_platform.is_cuda():
|
||||
return False
|
||||
capability_tuple = current_platform.get_device_capability()
|
||||
capability = -1 if capability_tuple is None else capability_tuple.to_int()
|
||||
return cutlass_scaled_mm_supports_fp4(capability)
|
||||
|
||||
|
||||
def break_fp4_bytes(a, dtype):
|
||||
assert a.dtype == torch.uint8
|
||||
m, n = a.shape
|
||||
|
Reference in New Issue
Block a user