mirror of https://github.com/vllm-project/vllm.git (synced 2025-10-20 23:03:52 +08:00)
[Feature] Migrate DeepGEMM API from get_m_alignment_for_contiguous_layout to get_mk_alignment_for_contiguous_layout (#26935)
Signed-off-by: yewentao256 <zhyanwentao@126.com>
Signed-off-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
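In short, callers stop importing deep_gemm directly (or going through the per-module deep_gemm_block_shape() helper) and instead use the shared, lazily initialized wrapper in vllm.utils.deep_gemm. A minimal before/after sketch of the call pattern, mirroring the hunks below (the concrete alignment value comes from the installed DeepGEMM):

# Before: each caller imported DeepGEMM itself and built the block shape by hand.
import deep_gemm as dg

block_m = dg.get_m_alignment_for_contiguous_layout()
block_size = [block_m, block_m]

# After: the vLLM wrapper returns the [m, k] alignment pair in one call.
from vllm.utils.deep_gemm import get_mk_alignment_for_contiguous_layout

block_size = get_mk_alignment_for_contiguous_layout()  # e.g. [128, 128]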
@@ -22,13 +22,13 @@ from vllm.model_executor.layers.fused_moe.fused_moe import (
 )
 from vllm.platforms import current_platform
 from vllm.utils import has_deep_gemm
-from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used
+from vllm.utils.deep_gemm import (
+    get_mk_alignment_for_contiguous_layout,
+    is_deep_gemm_e8m0_used,
+)
 
 dg_available = has_deep_gemm()
 
-if dg_available:
-    from deep_gemm import get_m_alignment_for_contiguous_layout
-
 if current_platform.get_device_capability() < (9, 0):
     pytest.skip("FP8 Triton requires CUDA 9.0 or higher", allow_module_level=True)
@@ -218,8 +218,7 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed, monkeypatch)
     torch.manual_seed(seed)
 
     monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", str(chunk_size))
-    block_m = get_m_alignment_for_contiguous_layout()
-    block_size = [block_m, block_m]
+    block_size = get_mk_alignment_for_contiguous_layout()
     dtype = torch.bfloat16
 
     a = torch.randn((M, K), dtype=dtype) / 10
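For context, the alignment the test pulls from the wrapper is the same one the DeepGEMM path checks shapes against (see _valid_deep_gemm_shape further down in this change). A small sketch of that constraint, assuming DeepGEMM is installed; the sizes are hypothetical:

from vllm.utils.deep_gemm import get_mk_alignment_for_contiguous_layout

align = get_mk_alignment_for_contiguous_layout()[0]
M, N, K = 256, 4096, 7168  # hypothetical problem sizes
# Mirrors _valid_deep_gemm_shape: enough rows, and N/K divisible by the alignment.
ok = align <= M and N % align == 0 and K % align == 0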
@@ -6,14 +6,17 @@ import torch
 import vllm.model_executor.layers.fused_moe.modular_kernel as mk
 from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
-from vllm.model_executor.layers.fused_moe.deep_gemm_utils import deep_gemm_block_shape
 from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
     TopKWeightAndReduceDelegate,
 )
 from vllm.model_executor.layers.fused_moe.utils import _resize_cache
 from vllm.platforms import current_platform
 from vllm.triton_utils import tl, triton
-from vllm.utils.deep_gemm import fp8_m_grouped_gemm_nt_masked, is_deep_gemm_e8m0_used
+from vllm.utils.deep_gemm import (
+    fp8_m_grouped_gemm_nt_masked,
+    get_mk_alignment_for_contiguous_layout,
+    is_deep_gemm_e8m0_used,
+)
 
 logger = init_logger(__name__)
@@ -227,7 +230,7 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
             quant_config: Quantization configuration
         """
         super().__init__(quant_config)
-        assert self.block_shape == deep_gemm_block_shape()
+        assert self.block_shape == get_mk_alignment_for_contiguous_layout()
         self.max_num_tokens = max_num_tokens
         self.num_dispatchers = num_dispatchers
@@ -8,8 +8,8 @@ from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
     BatchedDeepGemmExperts,
 )
 from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
-from vllm.model_executor.layers.fused_moe.deep_gemm_utils import deep_gemm_block_shape
 from vllm.model_executor.layers.fused_moe.fused_batched_moe import BatchedTritonExperts
+from vllm.utils.deep_gemm import get_mk_alignment_for_contiguous_layout
 
 
 class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
@@ -31,7 +31,7 @@ class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
         self.allow_deep_gemm = (
             allow_deep_gemm
             and self.quant_config.use_fp8_w8a8
-            and self.block_shape == deep_gemm_block_shape()
+            and self.block_shape == get_mk_alignment_for_contiguous_layout()
         )
 
         self.batched_deep_gemm_experts = (
@@ -13,7 +13,6 @@ from vllm.model_executor.layers.fused_moe.config import (
 )
 from vllm.model_executor.layers.fused_moe.deep_gemm_utils import (
     compute_aligned_M,
-    deep_gemm_block_shape,
     deepgemm_moe_permute,
     deepgemm_unpermute_and_reduce,
 )
@@ -28,14 +27,17 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import (
     per_token_group_quant_fp8,
 )
 from vllm.utils import has_deep_gemm
-from vllm.utils.deep_gemm import m_grouped_fp8_gemm_nt_contiguous
+from vllm.utils.deep_gemm import (
+    get_mk_alignment_for_contiguous_layout,
+    m_grouped_fp8_gemm_nt_contiguous,
+)
 from vllm.utils.functools import run_once
 
 logger = init_logger(__name__)
 
 
 def _valid_deep_gemm_shape(M: int, N: int, K: int) -> bool:
-    align = deep_gemm_block_shape()[0]
+    align = get_mk_alignment_for_contiguous_layout()[0]
     return align <= M and N % align == 0 and K % align == 0
@@ -54,7 +56,7 @@ def _valid_deep_gemm(
     M = hidden_states.size(0)
     _, K, N = w2.size()
 
-    align = deep_gemm_block_shape()[0]
+    align = get_mk_alignment_for_contiguous_layout()[0]
 
     if not _valid_deep_gemm_shape(M, N, K):
         logger.debug_once(
@@ -124,7 +126,7 @@ def warmup_deepgemm_gg_contiguous_kernels(
 
     assert w1.size(0) == w2.size(0), "w1 and w2 must have the same number of experts"
 
-    block_m = deep_gemm_block_shape()[0]
+    block_m = get_mk_alignment_for_contiguous_layout()[0]
     num_experts = w1.size(0)
     device = w1.device
@@ -173,7 +175,7 @@ def warmup_deepgemm_gg_contiguous_kernels(
 class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
     def __init__(self, quant_config: FusedMoEQuantConfig):
         super().__init__(quant_config)
-        assert quant_config.block_shape == deep_gemm_block_shape()
+        assert quant_config.block_shape == get_mk_alignment_for_contiguous_layout()
         assert quant_config.quant_dtype == torch.float8_e4m3fn
         assert not quant_config.per_act_token_quant
         assert not quant_config.per_out_ch_quant
@@ -255,7 +257,7 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
             M=topk_ids.size(0),
             num_topk=topk_ids.size(1),
             local_num_experts=local_num_experts,
-            alignment=deep_gemm_block_shape()[0],
+            alignment=get_mk_alignment_for_contiguous_layout()[0],
             expert_tokens_meta=expert_tokens_meta,
         )
@@ -364,7 +366,7 @@ def deep_gemm_moe_fp8(
         w2_scale=w2_scale,
         a1_scale=a1_scale,
         a2_scale=a2_scale,
-        block_shape=deep_gemm_block_shape(),
+        block_shape=get_mk_alignment_for_contiguous_layout(),
     )
 
     fn = mk.FusedMoEModularKernel(
@@ -5,23 +5,13 @@ Taken from https://github.com/ModelTC/LightLLM/blob/8ed97c74c18f11505b048b1ba00b
 and updated to fit vllm needs and terminology.
 """
 
-import functools
-
 import torch
 
 import vllm.model_executor.layers.fused_moe.modular_kernel as mk
 from vllm.model_executor.layers.fused_moe.utils import count_expert_num_tokens
 from vllm.triton_utils import tl, triton
 from vllm.utils import round_up
-
-
-@functools.cache
-def deep_gemm_block_shape() -> list[int]:
-    # Lazy import to avoid CUDA initialization problems.
-    import deep_gemm as dg
-
-    block = dg.get_m_alignment_for_contiguous_layout()
-    return [block, block]
+from vllm.utils.deep_gemm import get_mk_alignment_for_contiguous_layout
 
 
 def expert_num_tokens_round_up_and_sum(
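The deleted deep_gemm_block_shape() helper cached its result and deferred the deep_gemm import; the replacement wrapper (added later in this change in vllm/utils/deep_gemm.py) keeps both properties via functools.cache and _lazy_init(). A minimal sketch of the intended equivalence, assuming DeepGEMM is installed:

from vllm.utils.deep_gemm import get_mk_alignment_for_contiguous_layout

block_m, block_k = get_mk_alignment_for_contiguous_layout()
# The wrapper returns the same value for both dimensions, just as the old
# helper returned [block, block], so callers can keep treating it as a pair.
assert block_m == block_k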
@@ -354,8 +344,7 @@ def deepgemm_moe_permute(
     H = aq.size(1)
     device = aq.device
 
-    block_m = deep_gemm_block_shape()[0]
-    block_k = deep_gemm_block_shape()[1]
+    block_m, block_k = get_mk_alignment_for_contiguous_layout()
 
     M_sum = compute_aligned_M(
         M=topk_ids.size(0),
@@ -10,9 +10,11 @@ from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
     _valid_deep_gemm,
     _valid_deep_gemm_shape,
 )
-from vllm.model_executor.layers.fused_moe.deep_gemm_utils import deep_gemm_block_shape
 from vllm.model_executor.layers.fused_moe.fused_moe import TritonExperts
-from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used
+from vllm.utils.deep_gemm import (
+    get_mk_alignment_for_contiguous_layout,
+    is_deep_gemm_e8m0_used,
+)
 
 
 class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
@@ -28,7 +30,7 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
         self.allow_deep_gemm = (
             allow_deep_gemm
             and self.quant_config.use_fp8_w8a8
-            and self.block_shape == deep_gemm_block_shape()
+            and self.block_shape == get_mk_alignment_for_contiguous_layout()
         )
 
         self.deep_gemm_expert = (
@@ -12,10 +12,7 @@ from tqdm import tqdm
 import vllm.envs as envs
 from vllm.distributed.parallel_state import get_dp_group
 from vllm.model_executor.layers.fused_moe.deep_gemm_moe import DeepGemmExperts
-from vllm.model_executor.layers.fused_moe.deep_gemm_utils import (
-    compute_aligned_M,
-    deep_gemm_block_shape,
-)
+from vllm.model_executor.layers.fused_moe.deep_gemm_utils import compute_aligned_M
 from vllm.model_executor.layers.fused_moe.layer import FusedMoE
 from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel
 from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
@@ -23,7 +20,11 @@ from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
 )
 from vllm.model_executor.layers.linear import LinearBase
 from vllm.model_executor.layers.quantization.fp8 import Fp8LinearMethod
-from vllm.utils.deep_gemm import fp8_gemm_nt, m_grouped_fp8_gemm_nt_contiguous
+from vllm.utils.deep_gemm import (
+    fp8_gemm_nt,
+    get_mk_alignment_for_contiguous_layout,
+    m_grouped_fp8_gemm_nt_contiguous,
+)
 
 
 def _generate_optimal_warmup_m_values(
@@ -129,7 +130,7 @@ def _fp8_linear_may_use_deep_gemm(module: torch.nn.Module) -> bool:
     """
     Return True if the input module/layer could be processed with DeepGEMM.
     """
-    block_size = deep_gemm_block_shape()[0]
+    block_size = get_mk_alignment_for_contiguous_layout()[0]
     if not (
         isinstance(module, LinearBase)
         and isinstance(module.quant_method, Fp8LinearMethod)
@@ -139,7 +140,7 @@ def _fp8_linear_may_use_deep_gemm(module: torch.nn.Module) -> bool:
 
     w, _, block_sizes = _extract_data_from_linear_base_module(module)
     return (
-        block_sizes == deep_gemm_block_shape()
+        block_sizes == get_mk_alignment_for_contiguous_layout()
         and w.ndim == 2
         and w.shape[0] % block_size == 0
         and w.shape[1] % block_size == 0
@@ -155,7 +156,7 @@ def _fused_moe_grouped_gemm_may_use_deep_gemm(module: torch.nn.Module) -> bool:
     if (
         moe_quant_config is None
        or moe_quant_config.quant_dtype != torch.float8_e4m3fn
-        or moe_quant_config.block_shape != deep_gemm_block_shape()
+        or moe_quant_config.block_shape != get_mk_alignment_for_contiguous_layout()
     ):
         return False
@@ -176,7 +177,7 @@ def _deepgemm_fp8_gemm_nt_warmup(w: torch.Tensor, ws: torch.Tensor, max_tokens:
         return
 
     n, k = w.size()
-    block_m = deep_gemm_block_shape()[0]
+    block_m = get_mk_alignment_for_contiguous_layout()[0]
 
     device = w.device
     a1q = torch.empty((max_tokens, k), device=device, dtype=torch.float8_e4m3fn)
@@ -229,7 +230,7 @@ def _deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(
 
     assert w1.size(0) == w2.size(0), "w1 and w2 must have the same number of experts"
 
-    block_m = deep_gemm_block_shape()[0]
+    block_m = get_mk_alignment_for_contiguous_layout()[0]
     num_experts = w1.size(0)
     device = w1.device
@@ -75,6 +75,7 @@ _fp8_mqa_logits_impl: Callable[..., Any] | None = None
 _fp8_paged_mqa_logits_impl: Callable[..., Any] | None = None
 _get_paged_mqa_logits_metadata_impl: Callable[..., Any] | None = None
 _get_mn_major_tma_aligned_tensor_impl: Callable[..., Any] | None = None
+_get_mk_alignment_for_contiguous_layout_impl: Callable[..., Any] | None = None
 
 
 def _lazy_init() -> None:
@@ -83,7 +84,7 @@ def _lazy_init() -> None:
     global _fp8_mqa_logits_impl, _fp8_paged_mqa_logits_impl
     global _get_paged_mqa_logits_metadata_impl
     global _get_mn_major_tma_aligned_tensor_impl
-
+    global _get_mk_alignment_for_contiguous_layout_impl
     # fast path
     if (
         _fp8_gemm_nt_impl is not None
@@ -92,6 +93,7 @@ def _lazy_init() -> None:
         or _fp8_mqa_logits_impl is not None
         or _fp8_paged_mqa_logits_impl is not None
         or _get_paged_mqa_logits_metadata_impl is not None
+        or _get_mk_alignment_for_contiguous_layout_impl is not None
     ):
         return
@@ -118,6 +120,9 @@ def _lazy_init() -> None:
     _get_mn_major_tma_aligned_tensor_impl = getattr(
         _dg, "get_mn_major_tma_aligned_tensor", None
     )
+    _get_mk_alignment_for_contiguous_layout_impl = getattr(
+        _dg, "get_mk_alignment_for_contiguous_layout", None
+    )
 
 
 def get_num_sms() -> int:
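For readers unfamiliar with this module: _lazy_init() resolves optional DeepGEMM symbols once via getattr and caches them in module globals, leaving them as None when the installed DeepGEMM does not provide them. A simplified, standalone sketch of that pattern (illustrative names only, not the real module):

import importlib
from typing import Any, Callable

_impl: Callable[..., Any] | None = None  # hypothetical cached symbol


def _lazy_init_sketch() -> None:
    global _impl
    if _impl is not None:
        return  # fast path: already resolved
    try:
        dg = importlib.import_module("deep_gemm")
    except ImportError:
        return  # leave _impl as None; the public wrapper reports the missing dependency
    _impl = getattr(dg, "get_mk_alignment_for_contiguous_layout", None)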
@@ -126,6 +131,15 @@ def get_num_sms() -> int:
     return int(_dg.get_num_sms())
 
 
+@functools.cache
+def get_mk_alignment_for_contiguous_layout() -> list[int]:
+    _lazy_init()
+    if _get_mk_alignment_for_contiguous_layout_impl is None:
+        return _missing()
+    mk_align_size = _get_mk_alignment_for_contiguous_layout_impl()
+    return [mk_align_size, mk_align_size]
+
+
 def get_col_major_tma_aligned_tensor(x: torch.Tensor) -> torch.Tensor:
     """Wrapper for DeepGEMM's get_mn_major_tma_aligned_tensor"""
     _lazy_init()
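A short usage sketch of the new public wrapper, guarded on has_deep_gemm() the way the updated test does; the returned pair is cached after the first call and the concrete value depends on the installed DeepGEMM:

from vllm.utils import has_deep_gemm
from vllm.utils.deep_gemm import get_mk_alignment_for_contiguous_layout

if has_deep_gemm():
    block_size = get_mk_alignment_for_contiguous_layout()  # e.g. [128, 128]
else:
    block_size = None  # DeepGEMM path unavailable; callers fall back to the Triton experts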
@@ -338,4 +352,5 @@ __all__ = [
     "get_num_sms",
     "should_use_deepgemm_for_fp8_linear",
     "get_col_major_tma_aligned_tensor",
+    "get_mk_alignment_for_contiguous_layout",
 ]