mirror of https://github.com/vllm-project/vllm.git
synced 2025-10-20 23:03:52 +08:00
[Kernels] Support blocked fp8 quantization for compressed tensors MoE (#25219)
Signed-off-by: Bill Nell <bnell@redhat.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>

@@ -13,6 +13,7 @@ from compressed_tensors.quantization import (ActivationOrdering,
import vllm.envs as envs
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm import _custom_ops as ops
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import (
    FusedMoE, FusedMoEActivationFormat, FusedMoEConfig, FusedMoEMethodBase,
@@ -31,6 +32,9 @@ from vllm.model_executor.layers.quantization.utils import replace_parameter
from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
    build_flashinfer_fp4_cutlass_moe_prepare_finalize, reorder_w1w3_to_w3w1,
    select_nvfp4_gemm_impl)
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
    expert_weight_is_col_major, get_col_major_tma_aligned_tensor,
    requant_weight_ue8m0_inplace)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
    check_moe_marlin_supports_layer, marlin_make_workspace_new,
    marlin_moe_permute_scales)
@@ -45,6 +49,7 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types
from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used

logger = init_logger(__name__)

@@ -505,10 +510,12 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
            self.weight_quant.strategy == QuantizationStrategy.CHANNEL
            and self.input_quant.strategy == QuantizationStrategy.TOKEN)
        if not (per_tensor or per_channel):
            raise ValueError(
                "For FP8 Fused MoE layers, we require per tensor "
                "or channelwise, dynamic per token quantization. Found "
                f"{self.weight_quant}, {self.input_quant}")
            assert self.weight_quant.strategy == QuantizationStrategy.BLOCK
            self.weight_block_size = self.weight_quant.block_structure
            assert self.weight_quant.dynamic is not None
        else:
            self.weight_block_size = None
        self.block_quant = self.weight_block_size is not None

        self.static_input_scales = not self.input_quant.dynamic
        if self.static_input_scales and per_channel:
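
A condensed restatement of the scheme detection in the hunk above, with plain strings standing in for the compressed-tensors QuantizationStrategy enum values (illustrative only, not vLLM API):

def classify_fp8_moe_scheme(weight_strategy: str, input_strategy: str) -> str:
    # Mirrors the branch above: tensor/tensor and channel/token are the two
    # non-blocked schemes; anything else is treated as block-wise weights.
    per_tensor = weight_strategy == "tensor" and input_strategy == "tensor"
    per_channel = weight_strategy == "channel" and input_strategy == "token"
    if per_tensor:
        return "per_tensor"
    if per_channel:
        return "per_channel"
    return "block"

print(classify_fp8_moe_scheme("block", "token"))  # block
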
@@ -519,7 +526,8 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
        # For GPUs that lack FP8 hardware support, we can leverage the Marlin
        # kernel for fast weight-only FP8 quantization
        self.use_marlin = (not current_platform.has_device_capability(89)
                           or envs.VLLM_TEST_FORCE_FP8_MARLIN)
                           or envs.VLLM_TEST_FORCE_FP8_MARLIN
                           and not self.block_quant)
        # Disable marlin for rocm
        if current_platform.is_rocm():
            self.use_marlin = False
@@ -531,8 +539,9 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
        # cutlass path
        self.is_fp8_w8a8_sm100 = quant_config._is_fp8_w8a8_sm100(
            self.weight_quant, self.input_quant)
        self.use_cutlass = (quant_config._is_fp8_w8a8_sm90(
            self.weight_quant, self.input_quant) or self.is_fp8_w8a8_sm100)
        self.use_cutlass = not self.block_quant and (
            quant_config._is_fp8_w8a8_sm90(self.weight_quant, self.input_quant)
            or self.is_fp8_w8a8_sm100)
        self.disable_expert_map = False

    def create_weights(self, layer: torch.nn.Module, num_experts: int,
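
A small sketch of the backend gating set up in the two hunks above; the flag names and helper are illustrative, and the Marlin expression mirrors the parenthesization in the diff:

def moe_fp8_backend_flags(block_quant: bool, has_cap_89: bool,
                          force_marlin: bool, is_sm90_or_sm100: bool,
                          is_rocm: bool) -> dict:
    # Blocked fp8 weights skip the CUTLASS experts; Marlin is also disabled
    # on ROCm, and the test-override path follows the expression above.
    use_marlin = (not has_cap_89
                  or force_marlin
                  and not block_quant)
    if is_rocm:
        use_marlin = False
    use_cutlass = not block_quant and is_sm90_or_sm100
    return {"use_marlin": use_marlin, "use_cutlass": use_cutlass}

print(moe_fp8_backend_flags(block_quant=True, has_cap_89=True,
                            force_marlin=False, is_sm90_or_sm100=True,
                            is_rocm=False))
# {'use_marlin': False, 'use_cutlass': False}
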
@@ -547,6 +556,31 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):

        params_dtype = torch.float8_e4m3fn

        if self.block_quant:
            assert self.weight_block_size is not None
            layer.weight_block_size = self.weight_block_size
            tp_size = get_tensor_model_parallel_world_size()
            block_n, block_k = (
                self.weight_block_size[0],
                self.weight_block_size[1],
            )
            # NOTE: To ensure proper alignment of the block-wise quantization
            # scales, the output_size of the weights for both the gate and up
            # layers must be divisible by block_n.
            # Required by column parallel or enabling merged weights
            if intermediate_size_per_partition % block_n != 0:
                raise ValueError(
                    f"The output_size of gate's and up's weight = "
                    f"{intermediate_size_per_partition} is not divisible by "
                    f"weight quantization block_n = {block_n}.")
            if (tp_size > 1
                    and intermediate_size_per_partition % block_k != 0):
                # Required by row parallel
                raise ValueError(
                    f"The input_size of down's weight = "
                    f"{intermediate_size_per_partition} is not divisible by "
                    f"weight quantization block_k = {block_k}.")

        # WEIGHTS
        w13_weight = torch.nn.Parameter(torch.empty(
            num_experts,
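
The divisibility rule enforced above, restated as a standalone check (names are illustrative; weight_block_size is assumed to be [block_n, block_k], e.g. [128, 128]):

def validate_moe_block_shape(intermediate_size_per_partition: int,
                             weight_block_size: list,
                             tp_size: int) -> None:
    block_n, block_k = weight_block_size
    # Gate/up projections are column-parallel, so their output dim must be a
    # multiple of block_n for every scale block to land inside one shard.
    if intermediate_size_per_partition % block_n != 0:
        raise ValueError("intermediate size not divisible by block_n")
    # The down projection is row-parallel; with TP > 1 its input dim must
    # also be a multiple of block_k.
    if tp_size > 1 and intermediate_size_per_partition % block_k != 0:
        raise ValueError("intermediate size not divisible by block_k")

validate_moe_block_shape(1536, [128, 128], tp_size=2)  # passes
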
@@ -602,6 +636,27 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
            set_weight_attrs(w13_weight_scale, extra_weight_attrs)
            set_weight_attrs(w2_weight_scale, extra_weight_attrs)

        elif self.weight_quant.strategy == QuantizationStrategy.BLOCK:
            w13_weight_scale = torch.nn.Parameter(torch.ones(
                num_experts,
                2 *
                ((intermediate_size_per_partition + block_n - 1) // block_n),
                (hidden_size + block_k - 1) // block_k,
                dtype=torch.float32),
                requires_grad=False)
            layer.register_parameter("w13_weight_scale", w13_weight_scale)
            w2_weight_scale = torch.nn.Parameter(torch.ones(
                num_experts, (hidden_size + block_n - 1) // block_n,
                (intermediate_size_per_partition + block_k - 1) // block_k,
                dtype=torch.float32),
                requires_grad=False)
            layer.register_parameter("w2_weight_scale", w2_weight_scale)
            # Add PER-CHANNEL quantization for FusedMoE.weight_loader.
            extra_weight_attrs.update(
                {"quant_method": FusedMoeWeightScaleSupported.BLOCK.value})
            set_weight_attrs(w13_weight_scale, extra_weight_attrs)
            set_weight_attrs(w2_weight_scale, extra_weight_attrs)

        # INPUT_SCALES
        if self.static_input_scales:
            w13_input_scale = torch.nn.Parameter(torch.ones(
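
Worked example of the scale shapes allocated above, using hypothetical sizes (8 experts, hidden 4096, intermediate 1536 per partition, 128x128 blocks):

import math

num_experts, hidden_size, intermediate = 8, 4096, 1536
block_n, block_k = 128, 128

w13_scale_shape = (num_experts,
                   2 * math.ceil(intermediate / block_n),  # gate and up stacked
                   math.ceil(hidden_size / block_k))
w2_scale_shape = (num_experts,
                  math.ceil(hidden_size / block_n),
                  math.ceil(intermediate / block_k))
print(w13_scale_shape)  # (8, 24, 32) -- one fp32 scale per 128x128 block
print(w2_scale_shape)   # (8, 32, 12)
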
@@ -706,6 +761,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
            del layer.w2_input_scale

        if self.use_cutlass:
            assert self.weight_quant.strategy != QuantizationStrategy.BLOCK
            device = layer.w13_weight.device
            # ab_strides1 and c_strides2 are the same
            self.ab_strides1_c_strides2 = torch.full(
@@ -724,6 +780,29 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
                device=device,
                dtype=torch.int64)

        if is_deep_gemm_e8m0_used() and self.block_quant:
            assert layer.weight_block_size is not None
            # Re-quantise the expert weights so their scales are UE8M0.
            block_sz = tuple(layer.weight_block_size)
            requant_weight_ue8m0_inplace(
                layer.w13_weight.data,
                layer.w13_weight_scale.data,
                block_sz,
            )
            requant_weight_ue8m0_inplace(
                layer.w2_weight.data,
                layer.w2_weight_scale.data,
                block_sz,
            )

            # Ensure column-major TMA alignment expected by DeepGEMM.
            if expert_weight_is_col_major(layer.w13_weight_scale):
                layer.w13_weight_scale = get_col_major_tma_aligned_tensor(
                    layer.w13_weight_scale)
            if expert_weight_is_col_major(layer.w2_weight_scale):
                layer.w2_weight_scale = get_col_major_tma_aligned_tensor(
                    layer.w2_weight_scale)

    def maybe_make_prepare_finalize(
            self) -> Optional[mk.FusedMoEPrepareAndFinalize]:
        if self.use_marlin or self.rocm_aiter_moe_enabled:
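
For reference, UE8M0 scales carry only an 8-bit exponent, so after re-quantization each block scale is a power of two. A minimal sketch of that rounding (rounding up is an assumed policy here; the actual helper is requant_weight_ue8m0_inplace from fp8_utils):

import torch

def to_ue8m0(scale: torch.Tensor) -> torch.Tensor:
    # Round each scale up to the nearest power of two so no element
    # overflows after re-quantization (assumed policy, for illustration).
    return torch.exp2(torch.ceil(torch.log2(scale)))

s = torch.tensor([0.07, 1.3, 5.0])
print(to_ue8m0(s))  # tensor([0.1250, 2.0000, 8.0000])
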
@@ -777,9 +856,10 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
            return experts

        # triton path
        from vllm.model_executor.layers.fused_moe import TritonExperts
        from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
            BatchedTritonExperts)
        from vllm.model_executor.layers.fused_moe.batched_triton_or_deep_gemm_moe import (  # noqa: E501
            BatchedTritonOrDeepGemmExperts)
        from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
            TritonOrDeepGemmExperts)

        assert not self.rocm_aiter_moe_enabled and not self.use_marlin

@@ -790,14 +870,16 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
            assert max_num_tokens_per_rank is not None

            logger.debug("BatchedTritonExperts(%s)", self.__class__.__name__)
            return BatchedTritonExperts(
            return BatchedTritonOrDeepGemmExperts(
                max_num_tokens=max_num_tokens_per_rank,
                num_dispatchers=prepare_finalize.num_dispatchers(),
                quant_config=self.moe_quant_config,
            )
        else:
            logger.debug("TritonExperts(%s)", self.__class__.__name__)
            return TritonExperts(self.moe_quant_config)
            logger.debug("TritonOrDeepGemmExperts(%s)",
                         self.__class__.__name__)
            return TritonOrDeepGemmExperts(self.moe_quant_config,
                                           allow_deep_gemm=True)

    def get_fused_moe_quant_config(
            self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]:
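
A toy stand-in (not the vLLM classes) for the "Triton or DeepGEMM" wrappers returned above: pick the DeepGEMM path when it is supported and the config is block-quantized, otherwise fall back to Triton. The real wrappers decide based on the quant config and platform support; here it is reduced to a one-time selection:

from typing import Callable

def make_experts_runner(deep_gemm_supported: bool, block_quant: bool,
                        run_deep_gemm: Callable[[], str],
                        run_triton: Callable[[], str]) -> Callable[[], str]:
    # Prefer the DeepGEMM kernel for blocked fp8 when available.
    if deep_gemm_supported and block_quant:
        return run_deep_gemm
    return run_triton

runner = make_experts_runner(True, True,
                             lambda: "deep_gemm", lambda: "triton")
print(runner())  # deep_gemm
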
@@ -816,6 +898,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
            a2_scale=layer.w2_input_scale,
            per_act_token_quant=per_act_token,
            per_out_ch_quant=per_channel_quant,
            block_shape=layer.weight_block_size,
        )

    def apply(
@@ -33,10 +33,10 @@ from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
    apply_fp8_block_linear, check_aiter_fp8_linear_support,
    create_fp8_input_scale, create_fp8_scale_parameter,
    create_fp8_weight_parameter, get_col_major_tma_aligned_tensor,
    maybe_post_process_fp8_weight_block, process_fp8_weight_block_strategy,
    process_fp8_weight_tensor_strategy, requant_weight_ue8m0_inplace,
    validate_fp8_block_shape)
    create_fp8_weight_parameter, expert_weight_is_col_major,
    get_col_major_tma_aligned_tensor, maybe_post_process_fp8_weight_block,
    process_fp8_weight_block_strategy, process_fp8_weight_tensor_strategy,
    requant_weight_ue8m0_inplace, validate_fp8_block_shape)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
    apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin,
    prepare_moe_fp8_layer_for_marlin)
@@ -64,12 +64,6 @@ ACTIVATION_SCHEMES = ["static", "dynamic"]
logger = init_logger(__name__)


def _is_col_major(x: torch.Tensor) -> bool:
    assert x.dim() == 3
    b, m, n = x.shape
    return x.stride(0) == m * n and x.stride(1) == 1 and x.stride(2) == m


class Fp8Config(QuantizationConfig):
    """Config class for FP8."""

@@ -660,10 +654,10 @@ class Fp8MoEMethod(FusedMoEMethodBase):
        # DeepGemm scales need to be transposed and aligned. We try to do
        # it ahead of time for performance reasons.
        if self.allow_deep_gemm and not is_deep_gemm_e8m0_used():
            if _is_col_major(layer.w13_weight_scale_inv):
            if expert_weight_is_col_major(layer.w13_weight_scale_inv):
                layer.w13_weight_scale_inv = \
                    get_col_major_tma_aligned_tensor(layer.w13_weight_scale_inv)
            if _is_col_major(layer.w2_weight_scale_inv):
            if expert_weight_is_col_major(layer.w2_weight_scale_inv):
                layer.w2_weight_scale_inv = \
                    get_col_major_tma_aligned_tensor(layer.w2_weight_scale_inv)

@@ -811,10 +805,10 @@ class Fp8MoEMethod(FusedMoEMethodBase):
            )

            # Ensure column-major TMA alignment expected by DeepGEMM.
            if _is_col_major(layer.w13_weight_scale_inv):
            if expert_weight_is_col_major(layer.w13_weight_scale_inv):
                layer.w13_weight_scale_inv = get_col_major_tma_aligned_tensor(
                    layer.w13_weight_scale_inv)
            if _is_col_major(layer.w2_weight_scale_inv):
            if expert_weight_is_col_major(layer.w2_weight_scale_inv):
                layer.w2_weight_scale_inv = get_col_major_tma_aligned_tensor(
                    layer.w2_weight_scale_inv)

@@ -1014,3 +1014,9 @@ def apply_fp8_block_linear(layer: torch.nn.Module, input: torch.Tensor,
        cutlass_block_fp8_supported=cutlass_block_fp8_supported,
        use_aiter_and_is_supported=use_aiter_and_is_supported,
    )


def expert_weight_is_col_major(x: torch.Tensor) -> bool:
    assert x.dim() == 3
    b, m, n = x.shape
    return x.stride(0) == m * n and x.stride(1) == 1 and x.stride(2) == m
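
Quick check of the helper added above on a toy 3-D scale tensor (stand-alone copy, for illustration only):

import torch

def expert_weight_is_col_major(x: torch.Tensor) -> bool:
    assert x.dim() == 3
    b, m, n = x.shape
    return x.stride(0) == m * n and x.stride(1) == 1 and x.stride(2) == m

x = torch.zeros(4, 6, 8)  # (num_experts, m, n), row-major strides (48, 8, 1)
col_major = x.transpose(1, 2).contiguous().transpose(1, 2)

print(expert_weight_is_col_major(x))          # False
print(expert_weight_is_col_major(col_major))  # True -- strides (48, 1, 6)
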
@@ -53,9 +53,9 @@ def _extract_data_from_fused_moe_module(
    """
    assert isinstance(m, FusedMoE)
    w13 = m.w13_weight
    w13_s = m.w13_weight_scale_inv
    w13_s = getattr(m, "w13_weight_scale_inv", m.w13_weight_scale)
    w2 = m.w2_weight
    w2_s = m.w2_weight_scale_inv
    w2_s = getattr(m, "w2_weight_scale_inv", m.w2_weight_scale)
    num_topk = m.top_k

    assert isinstance(w13, torch.Tensor)
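
Minimal stand-ins (not vLLM classes) for the attribute fallback above: the fp8 path registers w13_weight_scale_inv, while the compressed-tensors block path registers w13_weight_scale, and getattr accepts either name:

class _FakeFusedMoE:
    def __init__(self, use_inv_name: bool):
        self.w13_weight_scale = "scale"
        if use_inv_name:
            self.w13_weight_scale_inv = "inv-scale"

for module in (_FakeFusedMoE(True), _FakeFusedMoE(False)):
    # Note: the getattr default is evaluated eagerly, so w13_weight_scale
    # must always exist on the module.
    w13_s = getattr(module, "w13_weight_scale_inv", module.w13_weight_scale)
    print(w13_s)  # "inv-scale", then "scale"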