mirror of
				https://github.com/vllm-project/vllm.git
				synced 2025-10-20 23:03:52 +08:00 
			
		
		
		
	[Kernels] Support blocked fp8 quantization for compressed tensors MoE (#25219)
Signed-off-by: Bill Nell <bnell@redhat.com> Co-authored-by: Michael Goin <mgoin64@gmail.com>
This commit is contained in:
		| @ -13,6 +13,7 @@ from compressed_tensors.quantization import (ActivationOrdering, | ||||
| import vllm.envs as envs | ||||
| import vllm.model_executor.layers.fused_moe.modular_kernel as mk | ||||
| from vllm import _custom_ops as ops | ||||
| from vllm.distributed import get_tensor_model_parallel_world_size | ||||
| from vllm.logger import init_logger | ||||
| from vllm.model_executor.layers.fused_moe import ( | ||||
|     FusedMoE, FusedMoEActivationFormat, FusedMoEConfig, FusedMoEMethodBase, | ||||
| @ -31,6 +32,9 @@ from vllm.model_executor.layers.quantization.utils import replace_parameter | ||||
| from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import ( | ||||
|     build_flashinfer_fp4_cutlass_moe_prepare_finalize, reorder_w1w3_to_w3w1, | ||||
|     select_nvfp4_gemm_impl) | ||||
| from vllm.model_executor.layers.quantization.utils.fp8_utils import ( | ||||
|     expert_weight_is_col_major, get_col_major_tma_aligned_tensor, | ||||
|     requant_weight_ue8m0_inplace) | ||||
| from vllm.model_executor.layers.quantization.utils.marlin_utils import ( | ||||
|     check_moe_marlin_supports_layer, marlin_make_workspace_new, | ||||
|     marlin_moe_permute_scales) | ||||
| @ -45,6 +49,7 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( | ||||
| from vllm.model_executor.utils import set_weight_attrs | ||||
| from vllm.platforms import current_platform | ||||
| from vllm.scalar_type import scalar_types | ||||
| from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used | ||||
|  | ||||
| logger = init_logger(__name__) | ||||
|  | ||||
| @ -505,10 +510,12 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): | ||||
|             self.weight_quant.strategy == QuantizationStrategy.CHANNEL | ||||
|             and self.input_quant.strategy == QuantizationStrategy.TOKEN) | ||||
|         if not (per_tensor or per_channel): | ||||
|             raise ValueError( | ||||
|                 "For FP8 Fused MoE layers, we require per tensor " | ||||
|                 "or channelwise, dynamic per token quantization. Found " | ||||
|                 f"{self.weight_quant}, {self.input_quant}") | ||||
|             assert self.weight_quant.strategy == QuantizationStrategy.BLOCK | ||||
|             self.weight_block_size = self.weight_quant.block_structure | ||||
|             assert self.weight_quant.dynamic is not None | ||||
|         else: | ||||
|             self.weight_block_size = None | ||||
|         self.block_quant = self.weight_block_size is not None | ||||
|  | ||||
|         self.static_input_scales = not self.input_quant.dynamic | ||||
|         if self.static_input_scales and per_channel: | ||||
| @ -519,7 +526,8 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): | ||||
|         # For GPUs that lack FP8 hardware support, we can leverage the Marlin | ||||
|         # kernel for fast weight-only FP8 quantization | ||||
|         self.use_marlin = (not current_platform.has_device_capability(89) | ||||
|                            or envs.VLLM_TEST_FORCE_FP8_MARLIN) | ||||
|                            or envs.VLLM_TEST_FORCE_FP8_MARLIN | ||||
|                            and not self.block_quant) | ||||
|         # Disable marlin for rocm | ||||
|         if current_platform.is_rocm(): | ||||
|             self.use_marlin = False | ||||
| @ -531,8 +539,9 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): | ||||
|         # cutlass path | ||||
|         self.is_fp8_w8a8_sm100 = quant_config._is_fp8_w8a8_sm100( | ||||
|             self.weight_quant, self.input_quant) | ||||
|         self.use_cutlass = (quant_config._is_fp8_w8a8_sm90( | ||||
|             self.weight_quant, self.input_quant) or self.is_fp8_w8a8_sm100) | ||||
|         self.use_cutlass = not self.block_quant and ( | ||||
|             quant_config._is_fp8_w8a8_sm90(self.weight_quant, self.input_quant) | ||||
|             or self.is_fp8_w8a8_sm100) | ||||
|         self.disable_expert_map = False | ||||
|  | ||||
|     def create_weights(self, layer: torch.nn.Module, num_experts: int, | ||||
| @ -547,6 +556,31 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): | ||||
|  | ||||
|         params_dtype = torch.float8_e4m3fn | ||||
|  | ||||
|         if self.block_quant: | ||||
|             assert self.weight_block_size is not None | ||||
|             layer.weight_block_size = self.weight_block_size | ||||
|             tp_size = get_tensor_model_parallel_world_size() | ||||
|             block_n, block_k = ( | ||||
|                 self.weight_block_size[0], | ||||
|                 self.weight_block_size[1], | ||||
|             ) | ||||
|             # NOTE: To ensure proper alignment of the block-wise quantization | ||||
|             # scales, the output_size of the weights for both the gate and up | ||||
|             # layers must be divisible by block_n. | ||||
|             # Required by column parallel or enabling merged weights | ||||
|             if intermediate_size_per_partition % block_n != 0: | ||||
|                 raise ValueError( | ||||
|                     f"The output_size of gate's and up's weight = " | ||||
|                     f"{intermediate_size_per_partition} is not divisible by " | ||||
|                     f"weight quantization block_n = {block_n}.") | ||||
|             if (tp_size > 1 | ||||
|                     and intermediate_size_per_partition % block_k != 0): | ||||
|                 # Required by row parallel | ||||
|                 raise ValueError( | ||||
|                     f"The input_size of down's weight = " | ||||
|                     f"{intermediate_size_per_partition} is not divisible by " | ||||
|                     f"weight quantization block_k = {block_k}.") | ||||
|  | ||||
|         # WEIGHTS | ||||
|         w13_weight = torch.nn.Parameter(torch.empty( | ||||
|             num_experts, | ||||
| @ -602,6 +636,27 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): | ||||
|             set_weight_attrs(w13_weight_scale, extra_weight_attrs) | ||||
|             set_weight_attrs(w2_weight_scale, extra_weight_attrs) | ||||
|  | ||||
|         elif self.weight_quant.strategy == QuantizationStrategy.BLOCK: | ||||
|             w13_weight_scale = torch.nn.Parameter(torch.ones( | ||||
|                 num_experts, | ||||
|                 2 * | ||||
|                 ((intermediate_size_per_partition + block_n - 1) // block_n), | ||||
|                 (hidden_size + block_k - 1) // block_k, | ||||
|                 dtype=torch.float32), | ||||
|                                                   requires_grad=False) | ||||
|             layer.register_parameter("w13_weight_scale", w13_weight_scale) | ||||
|             w2_weight_scale = torch.nn.Parameter(torch.ones( | ||||
|                 num_experts, (hidden_size + block_n - 1) // block_n, | ||||
|                 (intermediate_size_per_partition + block_k - 1) // block_k, | ||||
|                 dtype=torch.float32), | ||||
|                                                  requires_grad=False) | ||||
|             layer.register_parameter("w2_weight_scale", w2_weight_scale) | ||||
|             # Add PER-CHANNEL quantization for FusedMoE.weight_loader. | ||||
|             extra_weight_attrs.update( | ||||
|                 {"quant_method": FusedMoeWeightScaleSupported.BLOCK.value}) | ||||
|             set_weight_attrs(w13_weight_scale, extra_weight_attrs) | ||||
|             set_weight_attrs(w2_weight_scale, extra_weight_attrs) | ||||
|  | ||||
|         # INPUT_SCALES | ||||
|         if self.static_input_scales: | ||||
|             w13_input_scale = torch.nn.Parameter(torch.ones( | ||||
| @ -706,6 +761,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): | ||||
|             del layer.w2_input_scale | ||||
|  | ||||
|         if self.use_cutlass: | ||||
|             assert self.weight_quant.strategy != QuantizationStrategy.BLOCK | ||||
|             device = layer.w13_weight.device | ||||
|             # ab_strides1 and c_strides2 are the same | ||||
|             self.ab_strides1_c_strides2 = torch.full( | ||||
| @ -724,6 +780,29 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): | ||||
|                 device=device, | ||||
|                 dtype=torch.int64) | ||||
|  | ||||
|         if is_deep_gemm_e8m0_used() and self.block_quant: | ||||
|             assert layer.weight_block_size is not None | ||||
|             # Re-quantise the expert weights so their scales are UE8M0. | ||||
|             block_sz = tuple(layer.weight_block_size) | ||||
|             requant_weight_ue8m0_inplace( | ||||
|                 layer.w13_weight.data, | ||||
|                 layer.w13_weight_scale.data, | ||||
|                 block_sz, | ||||
|             ) | ||||
|             requant_weight_ue8m0_inplace( | ||||
|                 layer.w2_weight.data, | ||||
|                 layer.w2_weight_scale.data, | ||||
|                 block_sz, | ||||
|             ) | ||||
|  | ||||
|             # Ensure column-major TMA alignment expected by DeepGEMM. | ||||
|             if expert_weight_is_col_major(layer.w13_weight_scale): | ||||
|                 layer.w13_weight_scale = get_col_major_tma_aligned_tensor( | ||||
|                     layer.w13_weight_scale) | ||||
|             if expert_weight_is_col_major(layer.w2_weight_scale): | ||||
|                 layer.w2_weight_scale = get_col_major_tma_aligned_tensor( | ||||
|                     layer.w2_weight_scale) | ||||
|  | ||||
|     def maybe_make_prepare_finalize( | ||||
|             self) -> Optional[mk.FusedMoEPrepareAndFinalize]: | ||||
|         if self.use_marlin or self.rocm_aiter_moe_enabled: | ||||
| @ -777,9 +856,10 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): | ||||
|             return experts | ||||
|  | ||||
|         # triton path | ||||
|         from vllm.model_executor.layers.fused_moe import TritonExperts | ||||
|         from vllm.model_executor.layers.fused_moe.fused_batched_moe import ( | ||||
|             BatchedTritonExperts) | ||||
|         from vllm.model_executor.layers.fused_moe.batched_triton_or_deep_gemm_moe import (  # noqa: E501 | ||||
|             BatchedTritonOrDeepGemmExperts) | ||||
|         from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import ( | ||||
|             TritonOrDeepGemmExperts) | ||||
|  | ||||
|         assert not self.rocm_aiter_moe_enabled and not self.use_marlin | ||||
|  | ||||
| @ -790,14 +870,16 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): | ||||
|             assert max_num_tokens_per_rank is not None | ||||
|  | ||||
|             logger.debug("BatchedTritonExperts(%s)", self.__class__.__name__) | ||||
|             return BatchedTritonExperts( | ||||
|             return BatchedTritonOrDeepGemmExperts( | ||||
|                 max_num_tokens=max_num_tokens_per_rank, | ||||
|                 num_dispatchers=prepare_finalize.num_dispatchers(), | ||||
|                 quant_config=self.moe_quant_config, | ||||
|             ) | ||||
|         else: | ||||
|             logger.debug("TritonExperts(%s)", self.__class__.__name__) | ||||
|             return TritonExperts(self.moe_quant_config) | ||||
|             logger.debug("TritonOrDeepGemmExperts(%s)", | ||||
|                          self.__class__.__name__) | ||||
|             return TritonOrDeepGemmExperts(self.moe_quant_config, | ||||
|                                            allow_deep_gemm=True) | ||||
|  | ||||
|     def get_fused_moe_quant_config( | ||||
|             self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]: | ||||
| @ -816,6 +898,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): | ||||
|             a2_scale=layer.w2_input_scale, | ||||
|             per_act_token_quant=per_act_token, | ||||
|             per_out_ch_quant=per_channel_quant, | ||||
|             block_shape=layer.weight_block_size, | ||||
|         ) | ||||
|  | ||||
|     def apply( | ||||
|  | ||||
| @ -33,10 +33,10 @@ from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( | ||||
| from vllm.model_executor.layers.quantization.utils.fp8_utils import ( | ||||
|     apply_fp8_block_linear, check_aiter_fp8_linear_support, | ||||
|     create_fp8_input_scale, create_fp8_scale_parameter, | ||||
|     create_fp8_weight_parameter, get_col_major_tma_aligned_tensor, | ||||
|     maybe_post_process_fp8_weight_block, process_fp8_weight_block_strategy, | ||||
|     process_fp8_weight_tensor_strategy, requant_weight_ue8m0_inplace, | ||||
|     validate_fp8_block_shape) | ||||
|     create_fp8_weight_parameter, expert_weight_is_col_major, | ||||
|     get_col_major_tma_aligned_tensor, maybe_post_process_fp8_weight_block, | ||||
|     process_fp8_weight_block_strategy, process_fp8_weight_tensor_strategy, | ||||
|     requant_weight_ue8m0_inplace, validate_fp8_block_shape) | ||||
| from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( | ||||
|     apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin, | ||||
|     prepare_moe_fp8_layer_for_marlin) | ||||
| @ -64,12 +64,6 @@ ACTIVATION_SCHEMES = ["static", "dynamic"] | ||||
| logger = init_logger(__name__) | ||||
|  | ||||
|  | ||||
| def _is_col_major(x: torch.Tensor) -> bool: | ||||
|     assert x.dim() == 3 | ||||
|     b, m, n = x.shape | ||||
|     return x.stride(0) == m * n and x.stride(1) == 1 and x.stride(2) == m | ||||
|  | ||||
|  | ||||
| class Fp8Config(QuantizationConfig): | ||||
|     """Config class for FP8.""" | ||||
|  | ||||
| @ -660,10 +654,10 @@ class Fp8MoEMethod(FusedMoEMethodBase): | ||||
|             # DeepGemm scales need to be transposed and aligned. We try to do | ||||
|             # it ahead of time for performance reasons. | ||||
|             if self.allow_deep_gemm and not is_deep_gemm_e8m0_used(): | ||||
|                 if _is_col_major(layer.w13_weight_scale_inv): | ||||
|                 if expert_weight_is_col_major(layer.w13_weight_scale_inv): | ||||
|                     layer.w13_weight_scale_inv = \ | ||||
|                         get_col_major_tma_aligned_tensor(layer.w13_weight_scale_inv) | ||||
|                 if _is_col_major(layer.w2_weight_scale_inv): | ||||
|                 if expert_weight_is_col_major(layer.w2_weight_scale_inv): | ||||
|                     layer.w2_weight_scale_inv = \ | ||||
|                         get_col_major_tma_aligned_tensor(layer.w2_weight_scale_inv) | ||||
|  | ||||
| @ -811,10 +805,10 @@ class Fp8MoEMethod(FusedMoEMethodBase): | ||||
|             ) | ||||
|  | ||||
|             # Ensure column-major TMA alignment expected by DeepGEMM. | ||||
|             if _is_col_major(layer.w13_weight_scale_inv): | ||||
|             if expert_weight_is_col_major(layer.w13_weight_scale_inv): | ||||
|                 layer.w13_weight_scale_inv = get_col_major_tma_aligned_tensor( | ||||
|                     layer.w13_weight_scale_inv) | ||||
|             if _is_col_major(layer.w2_weight_scale_inv): | ||||
|             if expert_weight_is_col_major(layer.w2_weight_scale_inv): | ||||
|                 layer.w2_weight_scale_inv = get_col_major_tma_aligned_tensor( | ||||
|                     layer.w2_weight_scale_inv) | ||||
|  | ||||
|  | ||||
| @ -1014,3 +1014,9 @@ def apply_fp8_block_linear(layer: torch.nn.Module, input: torch.Tensor, | ||||
|         cutlass_block_fp8_supported=cutlass_block_fp8_supported, | ||||
|         use_aiter_and_is_supported=use_aiter_and_is_supported, | ||||
|     ) | ||||
|  | ||||
|  | ||||
| def expert_weight_is_col_major(x: torch.Tensor) -> bool: | ||||
|     assert x.dim() == 3 | ||||
|     b, m, n = x.shape | ||||
|     return x.stride(0) == m * n and x.stride(1) == 1 and x.stride(2) == m | ||||
|  | ||||
| @ -53,9 +53,9 @@ def _extract_data_from_fused_moe_module( | ||||
|     """ | ||||
|     assert isinstance(m, FusedMoE) | ||||
|     w13 = m.w13_weight | ||||
|     w13_s = m.w13_weight_scale_inv | ||||
|     w13_s = getattr(m, "w13_weight_scale_inv", m.w13_weight_scale) | ||||
|     w2 = m.w2_weight | ||||
|     w2_s = m.w2_weight_scale_inv | ||||
|     w2_s = getattr(m, "w2_weight_scale_inv", m.w2_weight_scale) | ||||
|     num_topk = m.top_k | ||||
|  | ||||
|     assert isinstance(w13, torch.Tensor) | ||||
|  | ||||
		Reference in New Issue
	
	Block a user