[Misc] DeepGemmExperts: Avoid JIT generation in the hot-path (#21955)

Signed-off-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
Co-authored-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
Varun Sundar Rabindranath
2025-08-02 08:12:03 +05:30
committed by GitHub
parent 57393715e8
commit a65f46be5e
3 changed files with 92 additions and 1 deletion

vllm/envs.py

@@ -126,6 +126,7 @@ if TYPE_CHECKING:
    VLLM_TPU_MOST_MODEL_LEN: Optional[int] = None
    VLLM_TPU_USING_PATHWAYS: bool = False
    VLLM_USE_DEEP_GEMM: bool = False
    VLLM_SKIP_DEEP_GEMM_WARMUP: bool = False
    VLLM_USE_FLASHINFER_MOE_FP8: bool = False
    VLLM_USE_FLASHINFER_MOE_FP4: bool = False
    VLLM_XGRAMMAR_CACHE_MB: int = 0
@@ -910,6 +911,14 @@ environment_variables: dict[str, Callable[[], Any]] = {
    "VLLM_USE_DEEP_GEMM":
    lambda: bool(int(os.getenv("VLLM_USE_DEEP_GEMM", "0"))),
    # DeepGemm JITs the kernels on-demand. The warmup attempts to make
    # DeepGemm JIT all the required kernels before model execution so there
    # is no JIT'ing in the hot-path. However, this warmup increases the
    # engine startup time by a couple of minutes.
    # Set `VLLM_SKIP_DEEP_GEMM_WARMUP=1` to disable the warmup.
    "VLLM_SKIP_DEEP_GEMM_WARMUP":
    lambda: bool(int(os.getenv("VLLM_SKIP_DEEP_GEMM_WARMUP", "0"))),
    # Allow use of FlashInfer MoE kernels for fused moe ops.
    "VLLM_USE_FLASHINFER_MOE_FP8":
    lambda: bool(int(os.getenv("VLLM_USE_FLASHINFER_MOE_FP8", "0"))),

vllm/model_executor/layers/fused_moe/deep_gemm_moe.py

@@ -4,7 +4,9 @@ import functools
from typing import Any, Optional
import torch
from tqdm import tqdm
import vllm.envs as env
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
@@ -17,7 +19,7 @@ from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
from vllm.model_executor.layers.fused_moe.utils import _resize_cache
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
per_token_group_quant_fp8)
from vllm.utils import has_deep_gemm
from vllm.utils import has_deep_gemm, run_once
from vllm.utils.deep_gemm import m_grouped_fp8_gemm_nt_contiguous
logger = init_logger(__name__)
@@ -82,6 +84,65 @@ def _valid_deep_gemm(hidden_states: torch.Tensor, w1: torch.Tensor,
    return True


@run_once
def warmup_deepgemm_gg_contiguous_kernels(w1: torch.Tensor, w2: torch.Tensor,
                                          w1_scale: torch.Tensor,
                                          w2_scale: torch.Tensor,
                                          num_topk: int):
    """
    DeepGemm JITs the grouped-gemm kernels. The JIT'ing happens based on the
    input tensor shapes. In this function, we construct all possible input
    tensor shapes so all the kernels are JIT'ed and cached.

    Note that this warmup is expected to happen during the model profile
    call and not during actual model inference.
    """
    assert w1.size(0) == w2.size(0), (
        "w1 and w2 must have the same number of experts")

    block_m = deep_gemm_block_shape()[0]
    num_experts = w1.size(0)
    device = w1.device

    # This is the maximum GroupedGemm M size that we expect to run
    # the grouped_gemm with.
    MAX_M = compute_aligned_M(env.VLLM_FUSED_MOE_CHUNK_SIZE,
                              num_topk,
                              num_experts,
                              block_m,
                              expert_tokens_meta=None)
    # Distribute expert-ids evenly.
    MAX_BLOCKS = MAX_M // block_m
    expert_ids_block = torch.randint(low=0,
                                     high=num_experts,
                                     size=(MAX_BLOCKS, ),
                                     device=device,
                                     dtype=torch.int32)
    expert_ids = torch.repeat_interleave(expert_ids_block, block_m, dim=0)

    def _warmup(w: torch.Tensor, w_scale: torch.Tensor):
        _, n, k = w.size()
        a1q = torch.empty((MAX_M, k), device=device).to(torch.float8_e4m3fn)
        a1q_scales = torch.empty((MAX_M, k // block_m),
                                 device=device,
                                 dtype=torch.float32)
        out = torch.empty((MAX_M, n), device=device, dtype=torch.bfloat16)

        pbar = tqdm(total=MAX_BLOCKS,
                    desc=f"DeepGemmExperts GEMM warmup (MAX_M={MAX_M})")
        num_tokens = MAX_M
        while num_tokens > 0:
            m_grouped_fp8_gemm_nt_contiguous(
                (a1q[:num_tokens], a1q_scales[:num_tokens]), (w, w_scale),
                out[:num_tokens], expert_ids[:num_tokens])
            pbar.update(1)
            num_tokens = num_tokens - block_m

    _warmup(w1, w1_scale)
    _warmup(w2, w2_scale)
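To make the warmup loop concrete, the following standalone sketch (illustrative values only: block_m = 128 matches the usual DeepGemm block shape, while MAX_M depends on the model, top-k and chunk size) enumerates the grouped-gemm M sizes the loop runs through, each of which triggers a JIT compile and cache fill:

# Illustrative sketch, not vLLM code: assumed block_m and MAX_M values.
block_m, MAX_M = 128, 1024
# The warmup loop above issues one grouped gemm per M value below.
warmed_up_m = list(range(MAX_M, 0, -block_m))
print(warmed_up_m)  # [1024, 896, 768, 640, 512, 384, 256, 128]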
class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):

    def __init__(self):
@@ -156,6 +217,20 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
    ):
        assert self.block_shape is not None
        assert a1q_scale is not None
        assert w1_scale is not None
        assert w2_scale is not None

        if not env.VLLM_SKIP_DEEP_GEMM_WARMUP:
            # DeepGemm JITs the grouped-gemm kernels. We don't want the
            # JIT'ing to happen during actual model inference. The
            # `warmup_deepgemm_gg_contiguous_kernels` function is decorated
            # with `run_once`, so it executes during the model profile run
            # and is a no-op afterwards. This warmup should create all the
            # JIT kernels required by the current model.
            warmup_deepgemm_gg_contiguous_kernels(w1,
                                                  w2,
                                                  w1_scale,
                                                  w2_scale,
                                                  num_topk=topk_ids.size(1))

        a1q = hidden_states
        _, N, K = w1.size()
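The gating above relies on `run_once` from `vllm.utils`: the first call, made during the profile run, performs the warmup, and every later call from the hot path returns immediately. A minimal sketch of the idea (an illustration only, not vLLM's actual implementation):

import functools

def run_once_sketch(fn):
    # First invocation runs fn; all later invocations are no-ops.
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        if not wrapper.has_run:
            wrapper.has_run = True
            return fn(*args, **kwargs)
    wrapper.has_run = False
    return wrapper

@run_once_sketch
def warmup():
    print("JIT warmup executed")

warmup()  # runs once (profile run)
warmup()  # no-op on subsequent (hot-path) calls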

vllm/utils/deep_gemm.py

@@ -8,6 +8,7 @@ from __future__ import annotations
import functools
import importlib
import os
from typing import Any, Callable, NoReturn
import torch
@@ -77,6 +78,12 @@ def _lazy_init() -> None:
    if not has_deep_gemm():
        return

    # Set up deep_gemm cache path
    DEEP_GEMM_JIT_CACHE_ENV_NAME = 'DG_JIT_CACHE_DIR'
    if not os.environ.get(DEEP_GEMM_JIT_CACHE_ENV_NAME, None):
        os.environ[DEEP_GEMM_JIT_CACHE_ENV_NAME] = os.path.join(
            envs.VLLM_CACHE_ROOT, "deep_gemm")
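Because `DG_JIT_CACHE_DIR` is only set here when it is not already present in the environment, the JIT cache location can be overridden before vLLM starts. A hedged example (the directory path below is made up):

import os

# Hypothetical override: point DeepGemm's JIT cache at a persistent
# directory so warmed-up kernels survive engine restarts. If unset,
# vLLM defaults it to <VLLM_CACHE_ROOT>/deep_gemm as shown above.
os.environ.setdefault("DG_JIT_CACHE_DIR", "/data/cache/deep_gemm")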
    _dg = importlib.import_module("deep_gemm")
    _fp8_gemm_nt_impl = _resolve_symbol(_dg, "fp8_gemm_nt",