From b3dda72c236f42fc8a2a7bd2003e0d394533bccd Mon Sep 17 00:00:00 2001 From: Wentao Ye <44945378+yewentao256@users.noreply.github.com> Date: Thu, 16 Oct 2025 16:46:48 -0400 Subject: [PATCH] [Feature] Migrate DeepGEMM API from `get_m_alignment_for_contiguous_layout` to `get_mk_alignment_for_contiguous_layout` (#26935) Signed-off-by: yewentao256 Signed-off-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- tests/kernels/moe/test_block_fp8.py | 11 +++++----- .../layers/fused_moe/batched_deep_gemm_moe.py | 9 +++++--- .../batched_triton_or_deep_gemm_moe.py | 4 ++-- .../layers/fused_moe/deep_gemm_moe.py | 18 +++++++++------- .../layers/fused_moe/deep_gemm_utils.py | 15 ++----------- .../layers/fused_moe/triton_deep_gemm_moe.py | 8 ++++--- .../model_executor/warmup/deep_gemm_warmup.py | 21 ++++++++++--------- vllm/utils/deep_gemm.py | 17 ++++++++++++++- 8 files changed, 57 insertions(+), 46 deletions(-) diff --git a/tests/kernels/moe/test_block_fp8.py b/tests/kernels/moe/test_block_fp8.py index b8cd3cb920..11b1e2ff3c 100644 --- a/tests/kernels/moe/test_block_fp8.py +++ b/tests/kernels/moe/test_block_fp8.py @@ -22,13 +22,13 @@ from vllm.model_executor.layers.fused_moe.fused_moe import ( ) from vllm.platforms import current_platform from vllm.utils import has_deep_gemm -from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used +from vllm.utils.deep_gemm import ( + get_mk_alignment_for_contiguous_layout, + is_deep_gemm_e8m0_used, +) dg_available = has_deep_gemm() -if dg_available: - from deep_gemm import get_m_alignment_for_contiguous_layout - if current_platform.get_device_capability() < (9, 0): pytest.skip("FP8 Triton requires CUDA 9.0 or higher", allow_module_level=True) @@ -218,8 +218,7 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed, monkeypatch) torch.manual_seed(seed) monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", str(chunk_size)) - block_m = get_m_alignment_for_contiguous_layout() - block_size = [block_m, block_m] + block_size = get_mk_alignment_for_contiguous_layout() dtype = torch.bfloat16 a = torch.randn((M, K), dtype=dtype) / 10 diff --git a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py index 91ce7e3019..095ec966ea 100644 --- a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py @@ -6,14 +6,17 @@ import torch import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig -from vllm.model_executor.layers.fused_moe.deep_gemm_utils import deep_gemm_block_shape from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( TopKWeightAndReduceDelegate, ) from vllm.model_executor.layers.fused_moe.utils import _resize_cache from vllm.platforms import current_platform from vllm.triton_utils import tl, triton -from vllm.utils.deep_gemm import fp8_m_grouped_gemm_nt_masked, is_deep_gemm_e8m0_used +from vllm.utils.deep_gemm import ( + fp8_m_grouped_gemm_nt_masked, + get_mk_alignment_for_contiguous_layout, + is_deep_gemm_e8m0_used, +) logger = init_logger(__name__) @@ -227,7 +230,7 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): quant_config: Quantization configuration """ super().__init__(quant_config) - assert self.block_shape == deep_gemm_block_shape() + assert 
self.block_shape == get_mk_alignment_for_contiguous_layout() self.max_num_tokens = max_num_tokens self.num_dispatchers = num_dispatchers diff --git a/vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py index 1b1af351a4..e69e9fd307 100644 --- a/vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py @@ -8,8 +8,8 @@ from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import ( BatchedDeepGemmExperts, ) from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig -from vllm.model_executor.layers.fused_moe.deep_gemm_utils import deep_gemm_block_shape from vllm.model_executor.layers.fused_moe.fused_batched_moe import BatchedTritonExperts +from vllm.utils.deep_gemm import get_mk_alignment_for_contiguous_layout class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): @@ -31,7 +31,7 @@ class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): self.allow_deep_gemm = ( allow_deep_gemm and self.quant_config.use_fp8_w8a8 - and self.block_shape == deep_gemm_block_shape() + and self.block_shape == get_mk_alignment_for_contiguous_layout() ) self.batched_deep_gemm_experts = ( diff --git a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py index c944140877..71776c654b 100644 --- a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py @@ -13,7 +13,6 @@ from vllm.model_executor.layers.fused_moe.config import ( ) from vllm.model_executor.layers.fused_moe.deep_gemm_utils import ( compute_aligned_M, - deep_gemm_block_shape, deepgemm_moe_permute, deepgemm_unpermute_and_reduce, ) @@ -28,14 +27,17 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import ( per_token_group_quant_fp8, ) from vllm.utils import has_deep_gemm -from vllm.utils.deep_gemm import m_grouped_fp8_gemm_nt_contiguous +from vllm.utils.deep_gemm import ( + get_mk_alignment_for_contiguous_layout, + m_grouped_fp8_gemm_nt_contiguous, +) from vllm.utils.functools import run_once logger = init_logger(__name__) def _valid_deep_gemm_shape(M: int, N: int, K: int) -> bool: - align = deep_gemm_block_shape()[0] + align = get_mk_alignment_for_contiguous_layout()[0] return align <= M and N % align == 0 and K % align == 0 @@ -54,7 +56,7 @@ def _valid_deep_gemm( M = hidden_states.size(0) _, K, N = w2.size() - align = deep_gemm_block_shape()[0] + align = get_mk_alignment_for_contiguous_layout()[0] if not _valid_deep_gemm_shape(M, N, K): logger.debug_once( @@ -124,7 +126,7 @@ def warmup_deepgemm_gg_contiguous_kernels( assert w1.size(0) == w2.size(0), "w1 and w2 must have the same number of experts" - block_m = deep_gemm_block_shape()[0] + block_m = get_mk_alignment_for_contiguous_layout()[0] num_experts = w1.size(0) device = w1.device @@ -173,7 +175,7 @@ def warmup_deepgemm_gg_contiguous_kernels( class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): def __init__(self, quant_config: FusedMoEQuantConfig): super().__init__(quant_config) - assert quant_config.block_shape == deep_gemm_block_shape() + assert quant_config.block_shape == get_mk_alignment_for_contiguous_layout() assert quant_config.quant_dtype == torch.float8_e4m3fn assert not quant_config.per_act_token_quant assert not quant_config.per_out_ch_quant @@ -255,7 +257,7 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): 
M=topk_ids.size(0), num_topk=topk_ids.size(1), local_num_experts=local_num_experts, - alignment=deep_gemm_block_shape()[0], + alignment=get_mk_alignment_for_contiguous_layout()[0], expert_tokens_meta=expert_tokens_meta, ) @@ -364,7 +366,7 @@ def deep_gemm_moe_fp8( w2_scale=w2_scale, a1_scale=a1_scale, a2_scale=a2_scale, - block_shape=deep_gemm_block_shape(), + block_shape=get_mk_alignment_for_contiguous_layout(), ) fn = mk.FusedMoEModularKernel( diff --git a/vllm/model_executor/layers/fused_moe/deep_gemm_utils.py b/vllm/model_executor/layers/fused_moe/deep_gemm_utils.py index 570c5ec09d..85294f6aea 100644 --- a/vllm/model_executor/layers/fused_moe/deep_gemm_utils.py +++ b/vllm/model_executor/layers/fused_moe/deep_gemm_utils.py @@ -5,23 +5,13 @@ Taken from https://github.com/ModelTC/LightLLM/blob/8ed97c74c18f11505b048b1ba00b and updated to fit vllm needs and terminology. """ -import functools - import torch import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.model_executor.layers.fused_moe.utils import count_expert_num_tokens from vllm.triton_utils import tl, triton from vllm.utils import round_up - - -@functools.cache -def deep_gemm_block_shape() -> list[int]: - # Lazy import to avoid CUDA initialization problems. - import deep_gemm as dg - - block = dg.get_m_alignment_for_contiguous_layout() - return [block, block] +from vllm.utils.deep_gemm import get_mk_alignment_for_contiguous_layout def expert_num_tokens_round_up_and_sum( @@ -354,8 +344,7 @@ def deepgemm_moe_permute( H = aq.size(1) device = aq.device - block_m = deep_gemm_block_shape()[0] - block_k = deep_gemm_block_shape()[1] + block_m, block_k = get_mk_alignment_for_contiguous_layout() M_sum = compute_aligned_M( M=topk_ids.size(0), diff --git a/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py index 908b1806ac..b8e0837162 100644 --- a/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py @@ -10,9 +10,11 @@ from vllm.model_executor.layers.fused_moe.deep_gemm_moe import ( _valid_deep_gemm, _valid_deep_gemm_shape, ) -from vllm.model_executor.layers.fused_moe.deep_gemm_utils import deep_gemm_block_shape from vllm.model_executor.layers.fused_moe.fused_moe import TritonExperts -from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used +from vllm.utils.deep_gemm import ( + get_mk_alignment_for_contiguous_layout, + is_deep_gemm_e8m0_used, +) class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): @@ -28,7 +30,7 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): self.allow_deep_gemm = ( allow_deep_gemm and self.quant_config.use_fp8_w8a8 - and self.block_shape == deep_gemm_block_shape() + and self.block_shape == get_mk_alignment_for_contiguous_layout() ) self.deep_gemm_expert = ( diff --git a/vllm/model_executor/warmup/deep_gemm_warmup.py b/vllm/model_executor/warmup/deep_gemm_warmup.py index f1ed2696a0..78cbcd8e54 100644 --- a/vllm/model_executor/warmup/deep_gemm_warmup.py +++ b/vllm/model_executor/warmup/deep_gemm_warmup.py @@ -12,10 +12,7 @@ from tqdm import tqdm import vllm.envs as envs from vllm.distributed.parallel_state import get_dp_group from vllm.model_executor.layers.fused_moe.deep_gemm_moe import DeepGemmExperts -from vllm.model_executor.layers.fused_moe.deep_gemm_utils import ( - compute_aligned_M, - deep_gemm_block_shape, -) +from vllm.model_executor.layers.fused_moe.deep_gemm_utils import compute_aligned_M from 
vllm.model_executor.layers.fused_moe.layer import FusedMoE from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import ( @@ -23,7 +20,11 @@ from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import ( ) from vllm.model_executor.layers.linear import LinearBase from vllm.model_executor.layers.quantization.fp8 import Fp8LinearMethod -from vllm.utils.deep_gemm import fp8_gemm_nt, m_grouped_fp8_gemm_nt_contiguous +from vllm.utils.deep_gemm import ( + fp8_gemm_nt, + get_mk_alignment_for_contiguous_layout, + m_grouped_fp8_gemm_nt_contiguous, +) def _generate_optimal_warmup_m_values( @@ -129,7 +130,7 @@ def _fp8_linear_may_use_deep_gemm(module: torch.nn.Module) -> bool: """ Return True if the input module/layer could be processed with DeepGEMM. """ - block_size = deep_gemm_block_shape()[0] + block_size = get_mk_alignment_for_contiguous_layout()[0] if not ( isinstance(module, LinearBase) and isinstance(module.quant_method, Fp8LinearMethod) @@ -139,7 +140,7 @@ def _fp8_linear_may_use_deep_gemm(module: torch.nn.Module) -> bool: w, _, block_sizes = _extract_data_from_linear_base_module(module) return ( - block_sizes == deep_gemm_block_shape() + block_sizes == get_mk_alignment_for_contiguous_layout() and w.ndim == 2 and w.shape[0] % block_size == 0 and w.shape[1] % block_size == 0 @@ -155,7 +156,7 @@ def _fused_moe_grouped_gemm_may_use_deep_gemm(module: torch.nn.Module) -> bool: if ( moe_quant_config is None or moe_quant_config.quant_dtype != torch.float8_e4m3fn - or moe_quant_config.block_shape != deep_gemm_block_shape() + or moe_quant_config.block_shape != get_mk_alignment_for_contiguous_layout() ): return False @@ -176,7 +177,7 @@ def _deepgemm_fp8_gemm_nt_warmup(w: torch.Tensor, ws: torch.Tensor, max_tokens: return n, k = w.size() - block_m = deep_gemm_block_shape()[0] + block_m = get_mk_alignment_for_contiguous_layout()[0] device = w.device a1q = torch.empty((max_tokens, k), device=device, dtype=torch.float8_e4m3fn) @@ -229,7 +230,7 @@ def _deepgemm_grouped_fp8_gemm_nt_contiguous_warmup( assert w1.size(0) == w2.size(0), "w1 and w2 must have the same number of experts" - block_m = deep_gemm_block_shape()[0] + block_m = get_mk_alignment_for_contiguous_layout()[0] num_experts = w1.size(0) device = w1.device diff --git a/vllm/utils/deep_gemm.py b/vllm/utils/deep_gemm.py index 39ffba3137..6c69e3fce7 100644 --- a/vllm/utils/deep_gemm.py +++ b/vllm/utils/deep_gemm.py @@ -75,6 +75,7 @@ _fp8_mqa_logits_impl: Callable[..., Any] | None = None _fp8_paged_mqa_logits_impl: Callable[..., Any] | None = None _get_paged_mqa_logits_metadata_impl: Callable[..., Any] | None = None _get_mn_major_tma_aligned_tensor_impl: Callable[..., Any] | None = None +_get_mk_alignment_for_contiguous_layout_impl: Callable[..., Any] | None = None def _lazy_init() -> None: @@ -83,7 +84,7 @@ def _lazy_init() -> None: global _fp8_mqa_logits_impl, _fp8_paged_mqa_logits_impl global _get_paged_mqa_logits_metadata_impl global _get_mn_major_tma_aligned_tensor_impl - + global _get_mk_alignment_for_contiguous_layout_impl # fast path if ( _fp8_gemm_nt_impl is not None @@ -92,6 +93,7 @@ def _lazy_init() -> None: or _fp8_mqa_logits_impl is not None or _fp8_paged_mqa_logits_impl is not None or _get_paged_mqa_logits_metadata_impl is not None + or _get_mk_alignment_for_contiguous_layout_impl is not None ): return @@ -118,6 +120,9 @@ def _lazy_init() -> None: _get_mn_major_tma_aligned_tensor_impl = getattr( _dg, "get_mn_major_tma_aligned_tensor", 
None ) + _get_mk_alignment_for_contiguous_layout_impl = getattr( + _dg, "get_mk_alignment_for_contiguous_layout", None + ) def get_num_sms() -> int: @@ -126,6 +131,15 @@ def get_num_sms() -> int: return int(_dg.get_num_sms()) +@functools.cache +def get_mk_alignment_for_contiguous_layout() -> list[int]: + _lazy_init() + if _get_mk_alignment_for_contiguous_layout_impl is None: + return _missing() + mk_align_size = _get_mk_alignment_for_contiguous_layout_impl() + return [mk_align_size, mk_align_size] + + def get_col_major_tma_aligned_tensor(x: torch.Tensor) -> torch.Tensor: """Wrapper for DeepGEMM's get_mn_major_tma_aligned_tensor""" _lazy_init() @@ -338,4 +352,5 @@ __all__ = [ "get_num_sms", "should_use_deepgemm_for_fp8_linear", "get_col_major_tma_aligned_tensor", + "get_mk_alignment_for_contiguous_layout", ]
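
Call-site migration at a glance: a minimal before/after sketch grounded in the hunks above, assuming DeepGEMM is installed so the new wrapper does not hit the `_missing()` fallback (the `block_shape`, `block_m`, and `block_k` names are illustrative, not part of the patch).

```python
# Before this patch: callers imported DeepGEMM directly (or went through the
# now-removed deep_gemm_block_shape() helper) and built the block shape by hand.
#   from deep_gemm import get_m_alignment_for_contiguous_layout
#   block_m = get_m_alignment_for_contiguous_layout()
#   block_shape = [block_m, block_m]

# After this patch: the lazily-initialized vLLM wrapper returns the
# [m_alignment, k_alignment] pair in a single cached call.
from vllm.utils.deep_gemm import get_mk_alignment_for_contiguous_layout

block_shape = get_mk_alignment_for_contiguous_layout()
block_m, block_k = block_shape
```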