[Misc] DeepGemmExperts : Avoid JIT generation in the hot-path (#21955)
Signed-off-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
Co-authored-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
parent 57393715e8
commit a65f46be5e
@@ -126,6 +126,7 @@ if TYPE_CHECKING:
     VLLM_TPU_MOST_MODEL_LEN: Optional[int] = None
     VLLM_TPU_USING_PATHWAYS: bool = False
     VLLM_USE_DEEP_GEMM: bool = False
+    VLLM_SKIP_DEEP_GEMM_WARMUP: bool = False
     VLLM_USE_FLASHINFER_MOE_FP8: bool = False
     VLLM_USE_FLASHINFER_MOE_FP4: bool = False
     VLLM_XGRAMMAR_CACHE_MB: int = 0

@@ -910,6 +911,14 @@ environment_variables: dict[str, Callable[[], Any]] = {
     "VLLM_USE_DEEP_GEMM":
     lambda: bool(int(os.getenv("VLLM_USE_DEEP_GEMM", "0"))),

+    # DeepGemm JITs the kernels on-demand. The warmup attempts to make DeepGemm
+    # JIT all the required kernels before model execution so there is no
+    # JIT'ing in the hot-path. However, this warmup increases the engine
+    # startup time by a couple of minutes.
+    # Set `VLLM_SKIP_DEEP_GEMM_WARMUP` to disable the warmup.
+    "VLLM_SKIP_DEEP_GEMM_WARMUP":
+    lambda: bool(int(os.getenv("VLLM_SKIP_DEEP_GEMM_WARMUP", "0"))),
+
     # Allow use of FlashInfer MoE kernels for fused moe ops.
     "VLLM_USE_FLASHINFER_MOE_FP8":
     lambda: bool(int(os.getenv("VLLM_USE_FLASHINFER_MOE_FP8", "0"))),
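For orientation, a minimal sketch (not part of the commit) of how the new flag resolves; only the variable name and the bool(int(...)) parsing come from the diff above, the rest is illustrative:

    import os

    # Illustrative only: mirrors the lambda registered for the flag above.
    os.environ["VLLM_SKIP_DEEP_GEMM_WARMUP"] = "1"  # e.g. exported before launching vLLM
    skip_warmup = bool(int(os.getenv("VLLM_SKIP_DEEP_GEMM_WARMUP", "0")))
    print(skip_warmup)  # True -> warmup is skipped and kernels are JIT'ed lazily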
@@ -4,7 +4,9 @@ import functools
 from typing import Any, Optional

 import torch
+from tqdm import tqdm

+import vllm.envs as env
 import vllm.model_executor.layers.fused_moe.modular_kernel as mk
 from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig

@@ -17,7 +19,7 @@ from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
 from vllm.model_executor.layers.fused_moe.utils import _resize_cache
 from vllm.model_executor.layers.quantization.utils.fp8_utils import (
     per_token_group_quant_fp8)
-from vllm.utils import has_deep_gemm
+from vllm.utils import has_deep_gemm, run_once
 from vllm.utils.deep_gemm import m_grouped_fp8_gemm_nt_contiguous

 logger = init_logger(__name__)
@@ -82,6 +84,65 @@ def _valid_deep_gemm(hidden_states: torch.Tensor, w1: torch.Tensor,
     return True


+@run_once
+def warmup_deepgemm_gg_contiguous_kernels(w1: torch.Tensor, w2: torch.Tensor,
+                                          w1_scale: torch.Tensor,
+                                          w2_scale: torch.Tensor,
+                                          num_topk: int):
+    """
+    DeepGemm JITs the grouped-gemm kernels. The JIT'ing happens based on the
+    input tensor shapes. In this function, we construct all possible input
+    tensor shapes so all the kernels are JIT'ed and cached.
+    Note that this warmup is expected to happen during the model profile
+    call and not during actual model inference.
+    """
+
+    assert w1.size(0) == w2.size(0), (
+        "w1 and w2 must have the same number of experts")
+
+    block_m = deep_gemm_block_shape()[0]
+    num_experts = w1.size(0)
+    device = w1.device
+
+    # This is the maximum GroupedGemm M size that we expect to run
+    # the grouped_gemm with.
+    MAX_M = compute_aligned_M(env.VLLM_FUSED_MOE_CHUNK_SIZE,
+                              num_topk,
+                              num_experts,
+                              block_m,
+                              expert_tokens_meta=None)
+    # Distribute expert-ids evenly.
+    MAX_BLOCKS = MAX_M // block_m
+    expert_ids_block = torch.randint(low=0,
+                                     high=num_experts,
+                                     size=(MAX_BLOCKS, ),
+                                     device=device,
+                                     dtype=torch.int32)
+    expert_ids = torch.repeat_interleave(expert_ids_block, block_m, dim=0)
+
+    def _warmup(w: torch.Tensor, w_scale: torch.Tensor):
+
+        _, n, k = w.size()
+        a1q = torch.empty((MAX_M, k), device=device).to(torch.float8_e4m3fn)
+        a1q_scales = torch.empty((MAX_M, k // block_m),
+                                 device=device,
+                                 dtype=torch.float32)
+        out = torch.empty((MAX_M, n), device=device, dtype=torch.bfloat16)
+
+        pbar = tqdm(total=MAX_BLOCKS,
+                    desc=f"DeepGemmExperts GEMM warmup (MAX_M={MAX_M})")
+        num_tokens = MAX_M
+        while num_tokens > 0:
+            m_grouped_fp8_gemm_nt_contiguous(
+                (a1q[:num_tokens], a1q_scales[:num_tokens]), (w, w_scale),
+                out[:num_tokens], expert_ids[:num_tokens])
+            pbar.update(1)
+            num_tokens = num_tokens - block_m
+
+    _warmup(w1, w1_scale)
+    _warmup(w2, w2_scale)
+
+
 class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):

     def __init__(self):
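As a side note before the next hunk, here is a small self-contained sketch (not from the commit, names hypothetical) of the shape-enumeration idea the warmup uses: every M the grouped GEMM sees is a multiple of block_m up to MAX_M, so stepping the token count down by block_m covers each kernel variant once:

    # Hypothetical helper illustrating the warmup loop above; not vLLM code.
    def enumerate_warmup_sizes(max_m: int, block_m: int) -> list[int]:
        sizes = []
        num_tokens = max_m
        while num_tokens > 0:
            sizes.append(num_tokens)
            num_tokens -= block_m
        return sizes

    # With block_m = 128 and MAX_M = 512, the warmup would run M = 512, 384, 256, 128.
    print(enumerate_warmup_sizes(max_m=512, block_m=128))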
@@ -156,6 +217,20 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
     ):
         assert self.block_shape is not None
         assert a1q_scale is not None
+        assert w1_scale is not None
+        assert w2_scale is not None
+
+        if not env.VLLM_SKIP_DEEP_GEMM_WARMUP:
+            # DeepGemm JITs the grouped-gemm kernels. We don't want the JIT'ing
+            # to happen during actual model-inference. The
+            # `warmup_deepgemm_kernels` function is a `run_once` decorated
+            # function that executes during the model profile run. This warmup
+            # should create all the required JITs for the current model.
+            warmup_deepgemm_gg_contiguous_kernels(w1,
+                                                  w2,
+                                                  w1_scale,
+                                                  w2_scale,
+                                                  num_topk=topk_ids.size(1))

         a1q = hidden_states
         _, N, K = w1.size()
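The guard in this hunk leans on run_once from vllm.utils: the warmup executes on the first call (expected to be the profile run) and later calls are no-ops. A rough sketch of those semantics, as an illustration rather than the actual vllm.utils.run_once implementation:

    import functools

    def run_once_sketch(f):
        # Execute the wrapped function only on its first invocation; ignore later calls.
        @functools.wraps(f)
        def wrapper(*args, **kwargs):
            if not wrapper.has_run:
                wrapper.has_run = True
                return f(*args, **kwargs)
        wrapper.has_run = False
        return wrapper

    @run_once_sketch
    def warmup():
        print("warming up")

    warmup()  # prints once
    warmup()  # no-op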
@@ -8,6 +8,7 @@ from __future__ import annotations

 import functools
 import importlib
+import os
 from typing import Any, Callable, NoReturn

 import torch
@@ -77,6 +78,12 @@ def _lazy_init() -> None:
     if not has_deep_gemm():
         return

+    # Set up deep_gemm cache path
+    DEEP_GEMM_JIT_CACHE_ENV_NAME = 'DG_JIT_CACHE_DIR'
+    if not os.environ.get(DEEP_GEMM_JIT_CACHE_ENV_NAME, None):
+        os.environ[DEEP_GEMM_JIT_CACHE_ENV_NAME] = os.path.join(
+            envs.VLLM_CACHE_ROOT, "deep_gemm")
+
     _dg = importlib.import_module("deep_gemm")

     _fp8_gemm_nt_impl = _resolve_symbol(_dg, "fp8_gemm_nt",
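A closing note on the cache-path hunk above: DG_JIT_CACHE_DIR points DeepGEMM's JIT at a persistent cache directory, and the commit only sets a default under VLLM_CACHE_ROOT, so a value exported by the user is left untouched. A minimal illustration of that precedence (the ~/.cache/vllm path is a stand-in for envs.VLLM_CACHE_ROOT, not taken from the diff):

    import os

    # Mirrors the default-only assignment above: a pre-set DG_JIT_CACHE_DIR wins.
    cache_root = os.path.expanduser("~/.cache/vllm")
    if not os.environ.get("DG_JIT_CACHE_DIR", None):
        os.environ["DG_JIT_CACHE_DIR"] = os.path.join(cache_root, "deep_gemm")
    print(os.environ["DG_JIT_CACHE_DIR"])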