# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Fused MoE kernel."""
import functools
import json
import os
from typing import Any, Callable, Optional

import torch

import vllm.envs as envs
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm import _custom_ops as ops
from vllm.logger import init_logger
# yapf: disable
from vllm.model_executor.layers.fused_moe.config import (
    FusedMoEQuantConfig, get_config_quant_dtype)
from vllm.model_executor.layers.fused_moe.cutlass_moe import (
    _valid_cutlass_block_scaled_grouped_gemm,
    run_cutlass_block_scaled_fused_experts)
# yapf: enable
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
    _valid_deep_gemm, deep_gemm_moe_fp8)
from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
    moe_align_block_size)
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
    MoEPrepareAndFinalizeNoEP)
from vllm.model_executor.layers.fused_moe.utils import (
    _resize_cache, moe_kernel_quantize_input)
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
from vllm.utils import direct_register_custom_op

from .rocm_aiter_fused_moe import is_rocm_aiter_moe_enabled

logger = init_logger(__name__)

@triton.jit
def write_zeros_to_output(c_ptr, stride_cm, stride_cn, pid_n, N, offs_token,
                          token_mask, BLOCK_SIZE_M, BLOCK_SIZE_N,
                          compute_type):
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=compute_type)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[
        None, :]
    c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
    tl.store(c_ptrs, accumulator, mask=c_mask)

@triton.jit
def fused_moe_kernel_gptq_awq(
        # Pointers to matrices
        a_ptr,
        b_ptr,
        c_ptr,
        b_scale_ptr,
        b_zp_ptr,
        topk_weights_ptr,
        sorted_token_ids_ptr,
        expert_ids_ptr,
        num_tokens_post_padded_ptr,
        # Matrix dimensions
        N: tl.constexpr,
        K: tl.constexpr,
        EM,
        num_valid_tokens,
        # The stride variables represent how much to increase the ptr by when
        # moving by 1 element in a particular dimension. E.g. `stride_am` is
        # how much to increase `a_ptr` by to get the element one row down
        # (A has M rows).
        stride_am,
        stride_ak,
        stride_be,
        stride_bk,
        stride_bn,
        stride_cm,
        stride_cn,
        stride_bse,
        stride_bsk,
        stride_bsn,
        stride_bze,
        stride_bzk,
        stride_bzn,
        block_k_diviable: tl.constexpr,
        group_size: tl.constexpr,
        # Meta-parameters
        BLOCK_SIZE_M: tl.constexpr,
        BLOCK_SIZE_N: tl.constexpr,
        BLOCK_SIZE_K: tl.constexpr,
        GROUP_SIZE_M: tl.constexpr,
        MUL_ROUTED_WEIGHT: tl.constexpr,
        top_k: tl.constexpr,
        compute_type: tl.constexpr,
        has_zp: tl.constexpr,
        use_int4_w4a16: tl.constexpr,
        use_int8_w8a16: tl.constexpr):
    """
    Implements the fused computation for a Mixture of Experts (MOE) using
    token and expert matrices.

    Key Parameters:
    - A: The input tensor representing tokens with shape (*, K), where '*' can
        be any shape representing batches and K is the feature dimension of
        each token.
    - B: The stacked MOE weight tensor with shape (E, N, K), where E is
        the number of experts, K is the input feature dimension, and N is
        the output feature dimension.
    - C: The output cache tensor with shape (M, topk, N), where M is the
        total number of tokens post padding, topk is the number of times
        each token is repeated, and N is the output feature dimension.
    - sorted_token_ids: A tensor containing the sorted indices of tokens,
        repeated topk times and arranged by the expert index they are
        assigned to.
    - expert_ids: A tensor containing the indices of the expert for each
        block. It determines which expert matrix from B should be used for
        each block in A.
    This kernel performs the multiplication of a token by its corresponding
    expert matrix as determined by `expert_ids`. The sorting of
    `sorted_token_ids` by expert index and padding ensures divisibility by
    BLOCK_SIZE_M, which is necessary to maintain consistency in block matrix
    multiplication across different blocks processed by the same expert.
    """
    # -----------------------------------------------------------
    # Map program ids `pid` to the block of C it should compute.
    # This is done in a grouped ordering to promote L2 data reuse.
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    # ----------------------------------------------------------
    # Create pointers for the first blocks of A and B.
    # We will advance this pointer as we move in the K direction
    # and accumulate
    # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers
    # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers
    num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)
    if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
        return
    offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(
        tl.int64)
    offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)
    token_mask = offs_token < num_valid_tokens

    off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64)
    if off_experts == -1:
        # -----------------------------------------------------------
        # Write back zeros to the output when the expert is not
        # in the current expert parallel rank.
        write_zeros_to_output(c_ptr, stride_cm, stride_cn, pid_n, N,
                              offs_token, token_mask, BLOCK_SIZE_M,
                              BLOCK_SIZE_N, compute_type)
        return

    offs_bn = (pid_n * BLOCK_SIZE_N +
               tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    a_ptrs = a_ptr + (offs_token[:, None] // top_k * stride_am +
                      offs_k[None, :] * stride_ak)

    if use_int4_w4a16:
        b_ptrs = b_ptr + off_experts * stride_be + \
            (offs_k[:, None] // 2) * stride_bk + offs_bn[None, :] * \
            stride_bn
        b_shifter = (offs_k[:, None] % 2) * 4
    elif use_int8_w8a16:
        b_ptrs = b_ptr + off_experts * stride_be + \
            offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn

    if not has_zp and use_int4_w4a16:
        b_zp_num = 8
    if not has_zp and use_int8_w8a16:
        b_zp_num = 128
    elif has_zp and use_int4_w4a16:
        b_zp_shifter = (offs_bn[None, :] % 2) * 4

    # -----------------------------------------------------------
    # Iterate to compute a block of the C matrix.
    # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
    # of fp32 values for higher accuracy.
    # `accumulator` will be converted back to fp16 after the loop.
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        # Load the next block of A and B, generate a mask by checking the
        # K dimension.

        if not block_k_diviable:
            k_mask = offs_k[:, None] < K - k * BLOCK_SIZE_K
            k_other = 0.0
        else:
            k_mask = None
            k_other = None

        a = tl.load(a_ptrs,
                    mask=token_mask[:, None] &
                    (offs_k[None, :] < K - k * BLOCK_SIZE_K),
                    other=0.0)
        b = tl.load(b_ptrs)
        if use_int4_w4a16:
            b = (b >> b_shifter) & 0xF

        b_scale_ptrs = b_scale_ptr + off_experts * stride_bse + \
            offs_bn[None, :] * stride_bsn + \
            ((offs_k[:, None] + BLOCK_SIZE_K * k) // group_size) * \
            stride_bsk
        b_scale = tl.load(b_scale_ptrs, mask=k_mask, other=k_other)
        b_scale = b_scale.to(tl.float32)

        if has_zp and use_int4_w4a16:
            offs_k_true = (offs_k[:, None] + BLOCK_SIZE_K * k) // group_size
            b_zp_ptrs = b_zp_ptr + off_experts * stride_bze + \
                (offs_bn[None, :] // 2) * stride_bzn + \
                offs_k_true * stride_bzk
            b_zp = tl.load(b_zp_ptrs, mask=k_mask, other=k_other)
            b_zp = ((b_zp >> b_zp_shifter) & 0xF)
            b_zp = b_zp.to(tl.float32)
        elif has_zp and use_int8_w8a16:
            offs_k_true = (offs_k[:, None] + BLOCK_SIZE_K * k) // group_size
            b_zp_ptrs = b_zp_ptr + off_experts * stride_bze + \
                offs_bn[None, :] * stride_bzn + \
                offs_k_true * stride_bzk
            b_zp = tl.load(b_zp_ptrs, mask=k_mask, other=k_other)
            b_zp = b_zp.to(tl.float32)

        # We accumulate along the K dimension.
        if has_zp:
            b = ((b.to(tl.float32) - b_zp) * b_scale).to(compute_type)
        else:
            b = ((b.to(tl.float32) - b_zp_num) * b_scale).to(compute_type)
        accumulator = tl.dot(a, b, acc=accumulator)

        # Advance the ptrs to the next K block.
        a_ptrs += BLOCK_SIZE_K * stride_ak
        if use_int4_w4a16:
            b_ptrs += (BLOCK_SIZE_K // 2) * stride_bk
        else:
            b_ptrs += BLOCK_SIZE_K * stride_bk

    if MUL_ROUTED_WEIGHT:
        moe_weight = tl.load(topk_weights_ptr + offs_token,
                             mask=token_mask,
                             other=0)
        accumulator = accumulator * moe_weight[:, None]

    accumulator = accumulator.to(compute_type)
    # -----------------------------------------------------------
    # Write back the block of the output
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[
        None, :]
    c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
    tl.store(c_ptrs, accumulator, mask=c_mask)

@triton.jit
def fused_moe_kernel(
        # Pointers to matrices
        a_ptr,
        b_ptr,
        c_ptr,
        a_scale_ptr,
        b_scale_ptr,
        topk_weights_ptr,
        sorted_token_ids_ptr,
        expert_ids_ptr,
        num_tokens_post_padded_ptr,
        # Matrix dimensions
        N,
        K,
        EM,
        num_valid_tokens,
        # The stride variables represent how much to increase the ptr by when
        # moving by 1 element in a particular dimension. E.g. `stride_am` is
        # how much to increase `a_ptr` by to get the element one row down
        # (A has M rows).
        stride_am,
        stride_ak,
        stride_be,
        stride_bk,
        stride_bn,
        stride_cm,
        stride_cn,
        stride_asm,
        stride_ask,
        stride_bse,
        stride_bsk,
        stride_bsn,
        # Block size for block-wise quantization
        group_n: tl.constexpr,
        group_k: tl.constexpr,
        # Meta-parameters
        BLOCK_SIZE_M: tl.constexpr,
        BLOCK_SIZE_N: tl.constexpr,
        BLOCK_SIZE_K: tl.constexpr,
        GROUP_SIZE_M: tl.constexpr,
        MUL_ROUTED_WEIGHT: tl.constexpr,
        top_k: tl.constexpr,
        compute_type: tl.constexpr,
        use_fp8_w8a8: tl.constexpr,
        use_int8_w8a8: tl.constexpr,
        use_int8_w8a16: tl.constexpr,
        per_channel_quant: tl.constexpr,
):
    """
    Implements the fused computation for a Mixture of Experts (MOE) using
    token and expert matrices.

    Key Parameters:
    - A: The input tensor representing tokens with shape (*, K), where '*' can
        be any shape representing batches and K is the feature dimension of
        each token.
    - B: The stacked MOE weight tensor with shape (E, N, K), where E is
        the number of experts, K is the input feature dimension, and N is
        the output feature dimension.
    - C: The output cache tensor with shape (M, topk, N), where M is the
        total number of tokens post padding, topk is the number of times
        each token is repeated, and N is the output feature dimension.
    - sorted_token_ids: A tensor containing the sorted indices of tokens,
        repeated topk times and arranged by the expert index they are
        assigned to.
    - expert_ids: A tensor containing the indices of the expert for each
        block. It determines which expert matrix from B should be used for
        each block in A.
    This kernel performs the multiplication of a token by its corresponding
    expert matrix as determined by `expert_ids`. The sorting of
    `sorted_token_ids` by expert index and padding ensures divisibility by
    BLOCK_SIZE_M, which is necessary to maintain consistency in block matrix
    multiplication across different blocks processed by the same expert.
    """
    # -----------------------------------------------------------
    # Map program ids `pid` to the block of C it should compute.
    # This is done in a grouped ordering to promote L2 data reuse.
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    # ----------------------------------------------------------
    # Create pointers for the first blocks of A and B.
    # We will advance this pointer as we move in the K direction
    # and accumulate
    # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers
    # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers
    num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)
    if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
        return
    offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(
        tl.int64)
    offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)
    token_mask = offs_token < num_valid_tokens

    off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64)
    if off_experts == -1:
        # -----------------------------------------------------------
        # Write back zeros to the output when the expert is not
        # in the current expert parallel rank.
        write_zeros_to_output(c_ptr, stride_cm, stride_cn, pid_n, N,
                              offs_token, token_mask, BLOCK_SIZE_M,
                              BLOCK_SIZE_N, compute_type)
        return

    offs_bn = (pid_n * BLOCK_SIZE_N +
               tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    a_ptrs = a_ptr + (offs_token[:, None] // top_k * stride_am +
                      offs_k[None, :] * stride_ak)

    b_ptrs = b_ptr + off_experts * stride_be + (offs_k[:, None] * stride_bk +
                                                offs_bn[None, :] * stride_bn)
    if use_int8_w8a16:
        b_scale_ptrs = b_scale_ptr + off_experts * stride_bse + offs_bn[
            None, :] * stride_bsn
        b_scale = tl.load(b_scale_ptrs)

    if use_fp8_w8a8 or use_int8_w8a8:
        # block-wise
        if group_k > 0 and group_n > 0:
            a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm
            offs_bsn = offs_bn // group_n
            b_scale_ptrs = (b_scale_ptr + off_experts * stride_bse +
                            offs_bsn * stride_bsn)
        # channel-wise
        elif per_channel_quant:
            b_scale_ptrs = b_scale_ptr + off_experts * stride_bse + offs_bn[
                None, :] * stride_bsn
            b_scale = tl.load(b_scale_ptrs)
            # Load per-token scale for activations
            a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm
            a_scale = tl.load(a_scale_ptrs, mask=token_mask, other=0.0)[:,
                                                                        None]
        # tensor-wise
        else:
            a_scale = tl.load(a_scale_ptr)
            b_scale = tl.load(b_scale_ptr + off_experts)

    # -----------------------------------------------------------
    # Iterate to compute a block of the C matrix.
    # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
    # of fp32 values for higher accuracy.
    # `accumulator` will be converted back to fp16 after the loop.
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        # Load the next block of A and B, generate a mask by checking the
        # K dimension.
        a = tl.load(a_ptrs,
                    mask=token_mask[:, None] &
                    (offs_k[None, :] < K - k * BLOCK_SIZE_K),
                    other=0.0)
        b = tl.load(b_ptrs,
                    mask=offs_k[:, None] < K - k * BLOCK_SIZE_K,
                    other=0.0)
        # We accumulate along the K dimension.
        if use_int8_w8a16:
            accumulator = tl.dot(a, b.to(compute_type), acc=accumulator)
        elif use_fp8_w8a8 or use_int8_w8a8:
            if group_k > 0 and group_n > 0:
                k_start = k * BLOCK_SIZE_K
                offs_ks = k_start // group_k
                a_scale = tl.load(a_scale_ptrs + offs_ks * stride_ask,
                                  mask=token_mask,
                                  other=0.0)
                b_scale = tl.load(b_scale_ptrs + offs_ks * stride_bsk)

                accumulator += tl.dot(a, b) * a_scale[:,
                                                      None] * b_scale[None, :]
            else:
                if use_fp8_w8a8:
                    # acc used to enable fp8_fast_accum
                    accumulator = tl.dot(a, b, acc=accumulator)
                else:
                    accumulator += tl.dot(a, b)
        else:
            accumulator += tl.dot(a, b)
        # Advance the ptrs to the next K block.
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk

    if MUL_ROUTED_WEIGHT:
        moe_weight = tl.load(topk_weights_ptr + offs_token,
                             mask=token_mask,
                             other=0)
        accumulator = accumulator * moe_weight[:, None]
    if use_int8_w8a16:
        accumulator = (accumulator * b_scale).to(compute_type)
    elif use_fp8_w8a8 or use_int8_w8a8:
        if group_k > 0 and group_n > 0:
            accumulator = accumulator.to(compute_type)
        else:
            accumulator = (accumulator * a_scale * b_scale).to(compute_type)
    else:
        accumulator = accumulator.to(compute_type)
    # -----------------------------------------------------------
    # Write back the block of the output
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[
        None, :]
    c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
    tl.store(c_ptrs, accumulator, mask=c_mask)

def invoke_fused_moe_kernel(A: torch.Tensor,
                            B: torch.Tensor,
                            C: torch.Tensor,
                            A_scale: Optional[torch.Tensor],
                            B_scale: Optional[torch.Tensor],
                            B_zp: Optional[torch.Tensor],
                            topk_weights: Optional[torch.Tensor],
                            sorted_token_ids: torch.Tensor,
                            expert_ids: torch.Tensor,
                            num_tokens_post_padded: torch.Tensor,
                            mul_routed_weight: bool,
                            top_k: int,
                            config: dict[str, Any],
                            compute_type: tl.dtype,
                            use_fp8_w8a8: bool,
                            use_int8_w8a8: bool,
                            use_int8_w8a16: bool,
                            use_int4_w4a16: bool,
                            per_channel_quant: bool,
                            block_shape: Optional[list[int]] = None) -> None:
    assert topk_weights is not None or not mul_routed_weight
    assert topk_weights is None or topk_weights.stride(1) == 1
    assert sorted_token_ids.stride(0) == 1

    if use_fp8_w8a8 or use_int8_w8a8:
        assert B_scale is not None
        assert (block_shape is None
                or triton.cdiv(B.size(-2), block_shape[0]) == B_scale.size(-2))
        assert (block_shape is None
                or triton.cdiv(B.size(-1), block_shape[1]) == B_scale.size(-1))

    elif use_int8_w8a16 or use_int4_w4a16:
        assert B_scale is not None
        assert block_shape is None or block_shape[0] == 0
    else:
        assert A_scale is None
        assert B_scale is None

    M = A.size(0)
    num_tokens = M * top_k

    EM = sorted_token_ids.size(0)
    if A.size(0) < config["BLOCK_SIZE_M"]:
        # Optimize for small batch sizes. We assume the top-k ids of each
        # token are unique, so num_valid_experts <= batch_size <=
        # BLOCK_SIZE_M, and we can skip some invalid blocks.
        EM = min(sorted_token_ids.size(0),
                 A.size(0) * top_k * config['BLOCK_SIZE_M'])
    grid = lambda META: (triton.cdiv(EM, META['BLOCK_SIZE_M']) * triton.cdiv(
        B.size(1), META['BLOCK_SIZE_N']), )

    if (use_int8_w8a16 or use_int4_w4a16) and \
            block_shape is not None and block_shape[1] > 0:
        assert B_scale is not None and B_scale.ndim == 3
        assert B_zp is None or B_zp.ndim == 3

        use_moe_wna16_cuda = should_moe_wna16_use_cuda(
            num_valid_tokens=num_tokens,
            group_size=block_shape[1],
            num_experts=B.size(0),
            bit=4 if use_int4_w4a16 else 8)
        config = config.copy()
        config.update(
            get_moe_wna16_block_config(config=config,
                                       use_moe_wna16_cuda=use_moe_wna16_cuda,
                                       num_valid_tokens=num_tokens,
                                       size_k=A.size(1),
                                       size_n=B.size(1),
                                       num_experts=B.size(1),
                                       group_size=block_shape[1],
                                       real_top_k=top_k,
                                       block_size_m=config["BLOCK_SIZE_M"]))

        if use_moe_wna16_cuda:
            bit = 4 if use_int4_w4a16 else 8
            ops.moe_wna16_gemm(A, C, B, B_scale, B_zp,
                               topk_weights if mul_routed_weight else None,
                               sorted_token_ids, expert_ids,
                               num_tokens_post_padded, top_k,
                               config["BLOCK_SIZE_M"], config["BLOCK_SIZE_N"],
                               config["BLOCK_SIZE_K"], bit)
            return

        fused_moe_kernel_gptq_awq[grid](
            A,
            B,
            C,
            B_scale,
            B_zp,
            topk_weights,
            sorted_token_ids,
            expert_ids,
            num_tokens_post_padded,
            B.size(1),
            A.size(1),
            EM,
            num_tokens,
            A.stride(0),
            A.stride(1),
            B.stride(0),
            B.stride(2),
            B.stride(1),
            C.stride(1),
            C.stride(2),
            B_scale.stride(0),
            B_scale.stride(2),
            B_scale.stride(1),
            B_zp.stride(0) if B_zp is not None else 0,
            B_zp.stride(2) if B_zp is not None else 0,
            B_zp.stride(1) if B_zp is not None else 0,
            block_k_diviable=A.size(1) % config["BLOCK_SIZE_K"] == 0,
            group_size=block_shape[1],
            MUL_ROUTED_WEIGHT=mul_routed_weight,
            top_k=top_k,
            compute_type=compute_type,
            has_zp=B_zp is not None,
            use_int4_w4a16=use_int4_w4a16,
            use_int8_w8a16=use_int8_w8a16,
            **config,
        )

    else:
        config = config.copy()
        BLOCK_SIZE_K = config.pop("BLOCK_SIZE_K")
        if block_shape is not None:
            BLOCK_SIZE_K = min(BLOCK_SIZE_K, min(block_shape[0],
                                                 block_shape[1]))
        fused_moe_kernel[grid](
            A,
            B,
            C,
            A_scale,
            B_scale,
            topk_weights,
            sorted_token_ids,
            expert_ids,
            num_tokens_post_padded,
            B.size(1),
            B.size(2),
            EM,
            num_tokens,
            A.stride(0),
            A.stride(1),
            B.stride(0),
            B.stride(2),
            B.stride(1),
            C.stride(1),
            C.stride(2),
            A_scale.stride(0)
            if A_scale is not None and A_scale.ndim == 2 else 0,
            A_scale.stride(1)
            if A_scale is not None and A_scale.ndim == 2 else 0,
            B_scale.stride(0)
            if B_scale is not None and B_scale.ndim >= 2 else 0,
            B_scale.stride(2)
            if B_scale is not None and B_scale.ndim == 3 else 0,
            B_scale.stride(1)
            if B_scale is not None and B_scale.ndim >= 2 else 0,
            0 if block_shape is None else block_shape[0],
            0 if block_shape is None else block_shape[1],
            MUL_ROUTED_WEIGHT=mul_routed_weight,
            top_k=top_k,
            compute_type=compute_type,
            use_fp8_w8a8=use_fp8_w8a8,
            use_int8_w8a8=use_int8_w8a8,
            use_int8_w8a16=use_int8_w8a16,
            per_channel_quant=per_channel_quant,
            BLOCK_SIZE_K=BLOCK_SIZE_K,
            **config,
        )

# Adapted from: https://github.com/sgl-project/sglang/pull/2628
def get_config_file_name(E: int,
                         N: int,
                         dtype: Optional[str],
                         block_shape: Optional[list[int]] = None) -> str:
    device_name = current_platform.get_device_name().replace(" ", "_")
    dtype_selector = "" if not dtype else f",dtype={dtype}"
    block_shape_selector = ("" if not block_shape or not all(block_shape) else
                            f",block_shape={block_shape}").replace(" ", "")
    return f"E={E},N={N},device_name={device_name}{dtype_selector}{block_shape_selector}.json"  # noqa: E501


# Adapted from: https://github.com/sgl-project/sglang/pull/2628
@functools.lru_cache
def get_moe_configs(
    E: int,
    N: int,
    dtype: Optional[str],
    block_n: Optional[int] = None,
    block_k: Optional[int] = None,
) -> Optional[dict[int, Any]]:
    """
    Return optimized configurations for the fused MoE kernel.

    The return value will be a dictionary that maps an irregular grid of
    batch sizes to configurations of the fused_moe kernel. To evaluate the
    kernel on a given batch size bs, the closest batch size in the grid should
    be picked and the associated configuration chosen to invoke the kernel.
    """

    # First look up if an optimized configuration is available in the configs
    # directory
    block_shape = [block_n, block_k] if block_n and block_k else None
    json_file_name = get_config_file_name(E, N, dtype, block_shape)

    config_file_path = os.path.join(
        os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name)
    if os.path.exists(config_file_path):
        with open(config_file_path) as f:
            logger.info("Using configuration from %s for MoE layer.",
                        config_file_path)
            # If a configuration has been found, return it
            return {int(key): val for key, val in json.load(f).items()}

    # If no optimized configuration is available, we will use the default
    # configuration
    logger.warning(
        ("Using default MoE config. Performance might be sub-optimal! "
         "Config file not found at %s"), config_file_path)
    return None

def get_moe_wna16_block_config(config: dict[str,
                                            int], use_moe_wna16_cuda: bool,
                               num_valid_tokens: int, size_k: int, size_n: int,
                               num_experts: int, group_size: int,
                               real_top_k: int, block_size_m: int):
    if "BLOCK_SIZE_N" in config and "BLOCK_SIZE_K" in config:
        # optimal block config is set
        return {}
    if not use_moe_wna16_cuda:
        # triton moe wna16 kernel
        if num_valid_tokens // real_top_k == 1:
            # if bs=1, use a smaller BLOCK_SIZE_N
            return {"BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 64}
        else:
            return {"BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32}
    else:
        # cuda moe wna16 kernel
        # set default block_size 128, and increase them when num_blocks
        # is too large.
        block_size_n = 128
        block_size_k = 128
        if block_size_k <= group_size:
            block_size_k = group_size

        num_n_blocks = size_k // block_size_k
        num_k_blocks = size_n // block_size_k
        num_m_blocks = (num_valid_tokens + block_size_m - 1) / block_size_m + \
            num_experts
        if num_valid_tokens // real_top_k <= block_size_m:
            num_m_blocks = min(num_m_blocks, num_valid_tokens)
        num_blocks = num_m_blocks * num_n_blocks * num_k_blocks

        if size_k % 256 == 0 and num_blocks >= 256 and \
                block_size_k < 256:
            block_size_k = 256
            num_blocks = num_blocks // (256 // block_size_k)

        if num_m_blocks <= 16 and size_k % (block_size_k * 2) == 0 and \
                size_k % (block_size_k * 2) == 0 and block_size_k <= 512 and \
                num_blocks >= 512:
            block_size_k = block_size_k * 2
            num_blocks = num_blocks // 2

        if num_blocks > 1024:
            block_size_n = 256
            num_n_blocks = num_n_blocks // 2
            num_blocks = num_blocks // 2

        if size_n <= 1024 and num_blocks >= 1024:
            # The kernel performance got much better with BLOCK_SIZE_N=1024
            # when num_blocks is large, even when N is small.
            # Not sure why; maybe it forces the CUDA SM to process only one
            # block at a time.
            block_size_n = 1024

        return {"BLOCK_SIZE_N": block_size_n, "BLOCK_SIZE_K": block_size_k}

def should_moe_wna16_use_cuda(num_valid_tokens: int, group_size: int,
                              num_experts: int, bit: int):
    return bit == 4 and group_size in [32, 64, 128] and \
        num_valid_tokens / num_experts <= 6

def get_default_config(
    M: int,
    E: int,
    N: int,
    K: int,
    topk: int,
    dtype: Optional[str],
    is_marlin: bool,
    block_shape: Optional[list[int]] = None,
) -> dict[str, int]:
    if dtype == "fp8_w8a8" and block_shape is not None:
        # Block-wise quant: BLOCK_SIZE_N must be divisible by block_shape[0]
        # BLOCK_SIZE_K must be divisible by block_shape[1]
        # num_stages=3 can cause triton.runtime.errors.OutOfResources
        # on ROCm, set it to 2 instead.
        config = {
            "BLOCK_SIZE_M": 64,
            "BLOCK_SIZE_N": block_shape[0],
            "BLOCK_SIZE_K": block_shape[1],
            "GROUP_SIZE_M": 32,
            "num_warps": 4,
            "num_stages": 3 if not current_platform.is_rocm() else 2,
        }
    elif dtype in ["int4_w4a16", "int8_w8a16"] and block_shape is not None:
        # moe wna16 kernels
        # only set BLOCK_SIZE_M
        # BLOCK_SIZE_N and BLOCK_SIZE_K would be set later
        bit = 4 if dtype == "int4_w4a16" else 8
        use_moe_wna16_cuda = should_moe_wna16_use_cuda(M * topk,
                                                       block_shape[1], E, bit)
        if use_moe_wna16_cuda:
            config = {"BLOCK_SIZE_M": min(16, M)}
        elif M <= 20:
            config = {"BLOCK_SIZE_M": 16, "GROUP_SIZE_M": 1}
        elif M <= 40:
            config = {"BLOCK_SIZE_M": 32, "GROUP_SIZE_M": 1}
        else:
            config = {"BLOCK_SIZE_M": 64, "GROUP_SIZE_M": 1}
    elif is_marlin:
        for block_size_m in [8, 16, 32, 48, 64]:
            if M * topk / E / block_size_m < 0.9:
                break
        return {"BLOCK_SIZE_M": block_size_m}
    elif M <= E:
        config = {
            "BLOCK_SIZE_M": 16,
            "BLOCK_SIZE_N": 32,
            "BLOCK_SIZE_K": 64,
            "GROUP_SIZE_M": 1,
        }
    else:
        config = {
            "BLOCK_SIZE_M": 64,
            "BLOCK_SIZE_N": 64,
            "BLOCK_SIZE_K": 32,
            "GROUP_SIZE_M": 8,
        }
    return config

def try_get_optimal_moe_config(
    w1_shape: tuple[int, ...],
    w2_shape: tuple[int, ...],
    top_k: int,
    dtype: Optional[str],
    M: int,
    is_marlin: bool = False,
    block_shape: Optional[list[int]] = None,
) -> dict[str, int]:
    from vllm.model_executor.layers.fused_moe import get_config
    override_config = get_config()
    if override_config:
        config = override_config
    else:
        # First try to load optimal config from the file
        E, _, N = w2_shape
        if dtype == "int4_w4a16":
            N = N * 2
        block_n = block_shape[0] if block_shape else 0
        block_k = block_shape[1] if block_shape else 0
        configs = get_moe_configs(E, N, dtype, block_n, block_k)

        if configs:
            # If an optimal configuration map has been found, look up the
            # optimal config
            config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
        else:
            # Else use the default config
            config = get_default_config(M, E, N, w1_shape[2], top_k, dtype,
                                        is_marlin, block_shape)
    return config
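

# Illustrative note (not part of the original source): the returned config is
# a plain dict of Triton launch parameters, for example something like
#     {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32,
#      "GROUP_SIZE_M": 8}
# (the default from get_default_config above); tuned JSON files and the fp8
# block-quant default may also carry "num_warps" and "num_stages".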


def vllm_topk_softmax(topk_weights: torch.Tensor, topk_indices: torch.Tensor,
                      token_expert_indices: torch.Tensor,
                      gating_output: torch.Tensor,
                      renormalize: bool) -> tuple[torch.Tensor, ...]:
    ops.topk_softmax(
        topk_weights,
        topk_indices,
        token_expert_indices,
        gating_output,
    )
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)

    return topk_weights, topk_indices

def dispatch_topk_func() -> Callable[..., tuple[torch.Tensor, ...]]:
    if is_rocm_aiter_moe_enabled():
        from .rocm_aiter_fused_moe import rocm_aiter_topk_softmax
        return rocm_aiter_topk_softmax
    return vllm_topk_softmax

def fused_topk(
    hidden_states: torch.Tensor,
    gating_output: torch.Tensor,
    topk: int,
    renormalize: bool,
    indices_type: Optional[torch.dtype] = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    assert hidden_states.size(0) == gating_output.size(0), (
        "Number of tokens mismatch")

    M, _ = hidden_states.size()

    topk_weights = torch.empty(M,
                               topk,
                               dtype=torch.float32,
                               device=hidden_states.device)
    topk_ids = torch.empty(
        M,
        topk,
        dtype=torch.int32 if indices_type is None else indices_type,
        device=hidden_states.device)
    token_expert_indices = torch.empty(M,
                                       topk,
                                       dtype=torch.int32,
                                       device=hidden_states.device)

    gating_output_float = gating_output.float()  # TODO(woosuk): Optimize this.

    topk_func = dispatch_topk_func()
    topk_weights, topk_ids = topk_func(topk_weights, topk_ids,
                                       token_expert_indices,
                                       gating_output_float, renormalize)

    return topk_weights, topk_ids, token_expert_indices
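

# Illustrative sketch (not part of the original source), assuming hypothetical
# sizes: for hidden_states of shape (num_tokens, hidden_size) and
# gating_output of shape (num_tokens, num_experts),
#     topk_weights, topk_ids, _ = fused_topk(hidden_states, gating_output,
#                                            topk=2, renormalize=True)
# returns float32 topk_weights and int32 topk_ids, both of shape
# (num_tokens, 2), ready to be passed to fused_experts().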


# This is used by the Deepseek-V2 and Deepseek-V3 model
@torch.compile(dynamic=True, backend=current_platform.simple_compile_backend)
def grouped_topk(
    hidden_states: torch.Tensor,
    gating_output: torch.Tensor,
    topk: int,
    renormalize: bool,
    num_expert_group: int = 0,
    topk_group: int = 0,
    scoring_func: str = "softmax",
    e_score_correction_bias: Optional[torch.Tensor] = None
) -> tuple[torch.Tensor, torch.Tensor]:

    assert hidden_states.size(0) == gating_output.size(0), (
        "Number of tokens mismatch")

    if scoring_func == "softmax":
        scores = torch.softmax(gating_output, dim=-1)
    elif scoring_func == "sigmoid":
        scores = gating_output.sigmoid()
    else:
        raise ValueError(f"Unsupported scoring function: {scoring_func}")

    num_token = scores.size(0)
    if e_score_correction_bias is not None:
        # Store original scores before applying correction bias. We use biased
        # scores for expert selection but original scores for routing weights
        original_scores = scores
        scores = scores + e_score_correction_bias.unsqueeze(0)
        group_scores = (scores.view(num_token, num_expert_group,
                                    -1).topk(2, dim=-1)[0].sum(dim=-1))
    else:
        group_scores = scores.view(num_token, num_expert_group,
                                   -1).max(dim=-1).values  # [n, n_group]
    group_idx = torch.topk(group_scores, k=topk_group, dim=-1,
                           sorted=False)[1]  # [n, top_k_group]
    group_mask = torch.zeros_like(group_scores)  # [n, n_group]
    group_mask.scatter_(1, group_idx, 1)  # [n, n_group]
    score_mask = group_mask.unsqueeze(-1).expand(
        num_token, num_expert_group,
        scores.size(-1) // num_expert_group).reshape(num_token, -1)  # [n, e]
    tmp_scores = scores.masked_fill(~score_mask.bool(),
                                    float("-inf"))  # [n, e]

    if e_score_correction_bias is not None:
        topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)[1]
        # Use original unbiased scores for the routing weights
        topk_weights = original_scores.gather(1, topk_ids)
    else:
        topk_weights, topk_ids = torch.topk(tmp_scores,
                                            k=topk,
                                            dim=-1,
                                            sorted=False)

    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)

    return topk_weights.to(torch.float32), topk_ids.to(torch.int32)

def get_config_dtype_str(
        dtype: torch.dtype,
        use_int4_w4a16: Optional[bool] = False,
        use_int8_w8a16: Optional[bool] = False,
        use_fp8_w8a8: Optional[bool] = False) -> Optional[str]:
    if use_fp8_w8a8:
        return "fp8_w8a8"
    elif use_int8_w8a16:
        return "int8_w8a16"
    elif use_int4_w4a16:
        return "int4_w4a16"
    elif dtype == torch.float:
        # Avoid cases where the kernel fails because a float32 MoE
        # picks up fp16/bfloat16 configs.
        return "float32"
    return None

def inplace_fused_experts(hidden_states: torch.Tensor,
                          w1: torch.Tensor,
                          w2: torch.Tensor,
                          topk_weights: torch.Tensor,
                          topk_ids: torch.Tensor,
                          activation: str = "silu",
                          apply_router_weight_on_input: bool = False,
                          use_fp8_w8a8: bool = False,
                          use_int8_w8a8: bool = False,
                          use_int8_w8a16: bool = False,
                          use_int4_w4a16: bool = False,
                          per_channel_quant: bool = False,
                          global_num_experts: int = -1,
                          expert_map: Optional[torch.Tensor] = None,
                          w1_scale: Optional[torch.Tensor] = None,
                          w2_scale: Optional[torch.Tensor] = None,
                          w1_zp: Optional[torch.Tensor] = None,
                          w2_zp: Optional[torch.Tensor] = None,
                          a1_scale: Optional[torch.Tensor] = None,
                          a2_scale: Optional[torch.Tensor] = None,
                          block_shape: Optional[list[int]] = None) -> None:
    fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, True,
                       activation, apply_router_weight_on_input, use_fp8_w8a8,
                       use_int8_w8a8, use_int8_w8a16, use_int4_w4a16,
                       per_channel_quant, global_num_experts, expert_map,
                       w1_scale, w2_scale, w1_zp, w2_zp, a1_scale, a2_scale,
                       block_shape)


def inplace_fused_experts_fake(
        hidden_states: torch.Tensor,
        w1: torch.Tensor,
        w2: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        activation: str = "silu",
        apply_router_weight_on_input: bool = False,
        use_fp8_w8a8: bool = False,
        use_int8_w8a8: bool = False,
        use_int8_w8a16: bool = False,
        use_int4_w4a16: bool = False,
        per_channel_quant: bool = False,
        global_num_experts: int = -1,
        expert_map: Optional[torch.Tensor] = None,
        w1_scale: Optional[torch.Tensor] = None,
        w2_scale: Optional[torch.Tensor] = None,
        w1_zp: Optional[torch.Tensor] = None,
        w2_zp: Optional[torch.Tensor] = None,
        a1_scale: Optional[torch.Tensor] = None,
        a2_scale: Optional[torch.Tensor] = None,
        block_shape: Optional[list[int]] = None) -> None:
    pass


direct_register_custom_op(
    op_name="inplace_fused_experts",
    op_func=inplace_fused_experts,
    mutates_args=["hidden_states"],
    fake_impl=inplace_fused_experts_fake,
    tags=(torch.Tag.needs_fixed_stride_order, ),
)

def outplace_fused_experts(
        hidden_states: torch.Tensor,
        w1: torch.Tensor,
        w2: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        activation: str = "silu",
        apply_router_weight_on_input: bool = False,
        use_fp8_w8a8: bool = False,
        use_int8_w8a8: bool = False,
        use_int8_w8a16: bool = False,
        use_int4_w4a16: bool = False,
        per_channel_quant: bool = False,
        global_num_experts: int = -1,
        expert_map: Optional[torch.Tensor] = None,
        w1_scale: Optional[torch.Tensor] = None,
        w2_scale: Optional[torch.Tensor] = None,
        w1_zp: Optional[torch.Tensor] = None,
        w2_zp: Optional[torch.Tensor] = None,
        a1_scale: Optional[torch.Tensor] = None,
        a2_scale: Optional[torch.Tensor] = None,
        block_shape: Optional[list[int]] = None) -> torch.Tensor:
    return fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids,
                              False, activation, apply_router_weight_on_input,
                              use_fp8_w8a8, use_int8_w8a8, use_int8_w8a16,
                              use_int4_w4a16, per_channel_quant,
                              global_num_experts, expert_map, w1_scale,
                              w2_scale, w1_zp, w2_zp, a1_scale, a2_scale,
                              block_shape)


def outplace_fused_experts_fake(
        hidden_states: torch.Tensor,
        w1: torch.Tensor,
        w2: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        activation: str = "silu",
        use_fp8_w8a8: bool = False,
        use_int8_w8a8: bool = False,
        use_int8_w8a16: bool = False,
        use_int4_w4a16: bool = False,
        per_channel_quant: bool = False,
        global_num_experts: int = -1,
        expert_map: Optional[torch.Tensor] = None,
        w1_scale: Optional[torch.Tensor] = None,
        w2_scale: Optional[torch.Tensor] = None,
        w1_zp: Optional[torch.Tensor] = None,
        w2_zp: Optional[torch.Tensor] = None,
        a1_scale: Optional[torch.Tensor] = None,
        a2_scale: Optional[torch.Tensor] = None,
        block_shape: Optional[list[int]] = None) -> torch.Tensor:
    return torch.empty_like(hidden_states)


direct_register_custom_op(
    op_name="outplace_fused_experts",
    op_func=outplace_fused_experts,
    mutates_args=[],
    fake_impl=outplace_fused_experts_fake,
    tags=(torch.Tag.needs_fixed_stride_order, ),
)

def torch_vllm_inplace_fused_experts(**kwargs) -> torch.Tensor:
    torch.ops.vllm.inplace_fused_experts(**kwargs)
    hidden_states = kwargs['hidden_states']
    return hidden_states


def torch_vllm_outplace_fused_experts(**kwargs) -> torch.Tensor:
    return torch.ops.vllm.outplace_fused_experts(**kwargs)


def dispatch_fused_experts_func(inplace: bool) -> Callable[..., torch.Tensor]:
    if inplace:
        return torch_vllm_inplace_fused_experts
    return torch_vllm_outplace_fused_experts

# TODO (bnell): replace this with modular op. Can get rid of inplace/outplace
# torch ops.
def fused_experts(
        hidden_states: torch.Tensor,
        w1: torch.Tensor,
        w2: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        inplace: bool = False,
        activation: str = "silu",
        apply_router_weight_on_input: bool = False,
        use_fp8_w8a8: bool = False,
        use_int8_w8a8: bool = False,
        use_int8_w8a16: bool = False,
        use_int4_w4a16: bool = False,
        per_channel_quant: bool = False,
        global_num_experts: int = -1,
        expert_map: Optional[torch.Tensor] = None,
        w1_scale: Optional[torch.Tensor] = None,
        w2_scale: Optional[torch.Tensor] = None,
        w1_zp: Optional[torch.Tensor] = None,
        w2_zp: Optional[torch.Tensor] = None,
        a1_scale: Optional[torch.Tensor] = None,
        a2_scale: Optional[torch.Tensor] = None,
        block_shape: Optional[list[int]] = None,
        allow_deep_gemm: bool = False,
        allow_cutlass_block_scaled_grouped_gemm: bool = False) -> torch.Tensor:
    # For now, disable DeepGemm for small N (<= 512) until better
    # permute/unpermute ops are available.
    N = w1.size(1)
    if (allow_deep_gemm and use_fp8_w8a8 and N > 512
            and _valid_deep_gemm(hidden_states, w1, w2)):
        assert apply_router_weight_on_input is False
        return deep_gemm_moe_fp8(
            hidden_states=hidden_states,
            w1=w1,
            w2=w2,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            inplace=inplace,
            activation=activation,
            global_num_experts=global_num_experts,
            expert_map=expert_map,
            w1_scale=w1_scale,
            w2_scale=w2_scale,
            a1_scale=a1_scale,
            a2_scale=a2_scale,
            apply_router_weight_on_input=apply_router_weight_on_input,
        )
    elif (allow_cutlass_block_scaled_grouped_gemm and use_fp8_w8a8
          and _valid_cutlass_block_scaled_grouped_gemm(hidden_states, w1, w2)):
        assert apply_router_weight_on_input is False
        return run_cutlass_block_scaled_fused_experts(
            a=hidden_states,
            w1=w1,
            w2=w2,
            w1_scale=w1_scale,
            w2_scale=w2_scale,
            topk_weights=topk_weights,
            topk_ids=topk_ids)
    else:
        return dispatch_fused_experts_func(inplace)(
            hidden_states=hidden_states,
            w1=w1,
            w2=w2,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            activation=activation,
            apply_router_weight_on_input=apply_router_weight_on_input,
            use_fp8_w8a8=use_fp8_w8a8,
            use_int8_w8a8=use_int8_w8a8,
            use_int8_w8a16=use_int8_w8a16,
            use_int4_w4a16=use_int4_w4a16,
            per_channel_quant=per_channel_quant,
            global_num_experts=global_num_experts,
            expert_map=expert_map,
            w1_scale=w1_scale,
            w2_scale=w2_scale,
            w1_zp=w1_zp,
            w2_zp=w2_zp,
            a1_scale=a1_scale,
            a2_scale=a2_scale,
            block_shape=block_shape)

def fused_experts_impl(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    inplace: bool = False,
    activation: str = "silu",
    apply_router_weight_on_input: bool = False,
    use_fp8_w8a8: bool = False,
    use_int8_w8a8: bool = False,
    use_int8_w8a16: bool = False,
    use_int4_w4a16: bool = False,
    per_channel_quant: bool = False,
    global_num_experts: int = -1,
    expert_map: Optional[torch.Tensor] = None,
    w1_scale: Optional[torch.Tensor] = None,
    w2_scale: Optional[torch.Tensor] = None,
    w1_zp: Optional[torch.Tensor] = None,
    w2_zp: Optional[torch.Tensor] = None,
    a1_scale: Optional[torch.Tensor] = None,
    a2_scale: Optional[torch.Tensor] = None,
    block_shape: Optional[list[int]] = None,
) -> torch.Tensor:
    # Check constraints.
    if use_int4_w4a16:
        assert hidden_states.size(1) // 2 == w1.size(2), (
            "Hidden size mismatch")
    else:
        assert hidden_states.size(1) == w1.size(2), (
            f"Hidden size mismatch {hidden_states.size(1)} != {w1.size(2)}")

    assert topk_weights.size() == topk_ids.size(), "topk shape mismatch"
    assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
    assert w1.stride(-1) == 1, "Stride of last dimension must be 1"
    assert w2.stride(-1) == 1, "Stride of last dimension must be 1"
    assert hidden_states.dtype in [
        torch.float32, torch.float16, torch.bfloat16
    ]

    num_tokens = hidden_states.size(0)
    E, N, _ = w1.size()
    K = w2.size(1)
    if global_num_experts == -1:
        global_num_experts = E
    top_k_num = topk_ids.size(1)
    # We execute the fused_moe kernel in chunks to circumvent this issue:
    # https://github.com/vllm-project/vllm/issues/5938
    CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
    M = min(num_tokens, CHUNK_SIZE)
    config_dtype = get_config_dtype_str(use_fp8_w8a8=use_fp8_w8a8,
                                        use_int8_w8a16=use_int8_w8a16,
                                        use_int4_w4a16=use_int4_w4a16,
                                        dtype=hidden_states.dtype)

    qtype = get_config_quant_dtype(use_fp8_w8a8=use_fp8_w8a8,
                                   use_int8_w8a8=use_int8_w8a8,
                                   use_int8_w8a16=use_int8_w8a16,
                                   use_int4_w4a16=use_int4_w4a16)

    get_config_func = functools.partial(
        try_get_optimal_moe_config,
        w1.size(),
        w2.size(),
        top_k_num,
        config_dtype,
        block_shape=block_shape,
    )

    config = get_config_func(M)

    # We can reuse the memory between these because by the time we need
    # cache3, we're done with cache1
    cache13 = torch.empty(M * top_k_num * max(N, K),
                          device=hidden_states.device,
                          dtype=hidden_states.dtype)
    intermediate_cache1 = cache13[:M * top_k_num * N].view(M, top_k_num, N)
    intermediate_cache3 = cache13[:M * top_k_num * K].view(M, top_k_num, K)

    # This needs separate memory since it's used concurrently with cache1
    intermediate_cache2 = torch.empty((M * top_k_num, N // 2),
                                      device=hidden_states.device,
                                      dtype=hidden_states.dtype)

    if hidden_states.dtype == torch.bfloat16:
        compute_type = tl.bfloat16
    elif hidden_states.dtype == torch.float16:
        compute_type = tl.float16
    elif hidden_states.dtype == torch.float32:
        compute_type = tl.float32
    else:
        raise ValueError(f"Unsupported compute_type: {hidden_states.dtype}")

    if inplace:
        out_hidden_states = hidden_states
    else:
        out_hidden_states = torch.empty_like(hidden_states)

    for chunk in range((num_tokens // CHUNK_SIZE) + 1):
        begin_chunk_idx, end_chunk_idx = (chunk * CHUNK_SIZE,
                                          min((chunk + 1) * CHUNK_SIZE,
                                              num_tokens))
        curr_hidden_states = hidden_states[begin_chunk_idx:end_chunk_idx]
        tokens_in_chunk, _ = curr_hidden_states.size()

        if tokens_in_chunk == 0:
            break

        if tokens_in_chunk < CHUNK_SIZE and chunk > 0:
            # Adjust the intermediate cache size and config for the last
            # chunk. Note that in most cases we only have one chunk
            # so the cache size and config are already set correctly and
            # do not need to be adjusted.
            intermediate_cache1 = intermediate_cache1[:tokens_in_chunk]
            intermediate_cache2 = intermediate_cache2[:tokens_in_chunk *
                                                      topk_ids.size(1)]
            intermediate_cache3 = intermediate_cache3[:tokens_in_chunk]
            config = get_config_func(tokens_in_chunk)

        curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx]
        curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx]

        qcurr_hidden_states, a1q_scale = moe_kernel_quantize_input(
            A=curr_hidden_states,
            A_scale=a1_scale,
            quant_dtype=qtype,
            per_act_token_quant=per_channel_quant,
            block_shape=block_shape)

        sorted_token_ids, expert_ids, num_tokens_post_padded = (
            moe_align_block_size(curr_topk_ids, config['BLOCK_SIZE_M'],
                                 global_num_experts, expert_map))

        invoke_fused_moe_kernel(qcurr_hidden_states,
                                w1,
                                intermediate_cache1,
                                a1q_scale,
                                w1_scale,
                                w1_zp,
                                curr_topk_weights,
                                sorted_token_ids,
                                expert_ids,
                                num_tokens_post_padded,
                                apply_router_weight_on_input,
                                top_k_num,
                                config,
                                compute_type=compute_type,
                                use_fp8_w8a8=use_fp8_w8a8,
                                use_int8_w8a8=use_int8_w8a8,
                                use_int8_w8a16=use_int8_w8a16,
                                use_int4_w4a16=use_int4_w4a16,
                                per_channel_quant=per_channel_quant,
                                block_shape=block_shape)

        if activation == "silu":
            torch.ops._C.silu_and_mul(intermediate_cache2,
                                      intermediate_cache1.view(-1, N))
        elif activation == "gelu":
            torch.ops._C.gelu_and_mul(intermediate_cache2,
                                      intermediate_cache1.view(-1, N))
        else:
            raise ValueError(f"Unsupported FusedMoe activation: {activation}")

        qintermediate_cache2, a2q_scale = moe_kernel_quantize_input(
            A=intermediate_cache2,
            A_scale=a2_scale,
            quant_dtype=qtype,
            per_act_token_quant=per_channel_quant,
            block_shape=block_shape)

        invoke_fused_moe_kernel(qintermediate_cache2,
                                w2,
                                intermediate_cache3,
                                a2q_scale,
                                w2_scale,
                                w2_zp,
                                curr_topk_weights,
                                sorted_token_ids,
                                expert_ids,
                                num_tokens_post_padded,
                                not apply_router_weight_on_input,
                                1,
                                config,
                                compute_type=compute_type,
                                use_fp8_w8a8=use_fp8_w8a8,
                                use_int8_w8a8=use_int8_w8a8,
                                use_int8_w8a16=use_int8_w8a16,
                                use_int4_w4a16=use_int4_w4a16,
                                per_channel_quant=per_channel_quant,
                                block_shape=block_shape)

        ops.moe_sum(intermediate_cache3.view(*intermediate_cache3.size()),
                    out_hidden_states[begin_chunk_idx:end_chunk_idx])

    return out_hidden_states

def fused_moe(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    gating_output: torch.Tensor,
    topk: int,
    renormalize: bool,
    inplace: bool = False,
    activation: str = "silu",
    use_grouped_topk: bool = False,
    num_expert_group: Optional[int] = None,
    topk_group: Optional[int] = None,
    custom_routing_function: Optional[Callable] = None,
    use_fp8_w8a8: bool = False,
    use_int8_w8a8: bool = False,
    use_int8_w8a16: bool = False,
    use_int4_w4a16: bool = False,
    per_channel_quant: bool = False,
    global_num_experts: int = -1,
    expert_map: Optional[torch.Tensor] = None,
    w1_scale: Optional[torch.Tensor] = None,
    w2_scale: Optional[torch.Tensor] = None,
    w1_zp: Optional[torch.Tensor] = None,
    w2_zp: Optional[torch.Tensor] = None,
    a1_scale: Optional[torch.Tensor] = None,
    a2_scale: Optional[torch.Tensor] = None,
    block_shape: Optional[list[int]] = None,
) -> torch.Tensor:
    """
    This function computes a Mixture of Experts (MoE) layer using two sets of
    weights, w1 and w2, and a top-k gating mechanism.

    Parameters:
    - hidden_states (torch.Tensor): The input tensor to the MoE layer.
    - w1 (torch.Tensor): The first set of expert weights.
    - w2 (torch.Tensor): The second set of expert weights.
    - gating_output (torch.Tensor): The output of the gating operation
        (before softmax).
    - topk (int): The number of top-k experts to select.
    - renormalize (bool): If True, renormalize the top-k weights to sum to 1.
    - inplace (bool): If True, perform the operation in-place.
        Defaults to False.
    - activation (str): The activation function to apply after the first
        MoE layer.
    - num_expert_group: Optional[int]: additional parameter for grouped_topk
    - topk_group: Optional[int]: additional parameter for grouped_topk
    - use_grouped_topk: If True, use grouped_topk instead of fused_topk
        note: Deepseekv2 model uses grouped_topk
    - use_fp8_w8a8 (bool): If True, use fp8 arithmetic to compute the inner
        products for w1 and w2. Defaults to False.
    - use_int8_w8a8 (bool): If True, use int8 arithmetic to compute the inner
        products for w1 and w2. Defaults to False.
    - use_int8_w8a16 (bool): If True, use matmul of int8 weight and bf16/fp16
        activation to compute the inner products for w1 and w2.
        Defaults to False.
    - use_int4_w4a16 (bool): If True, use matmul of int4 weight and bf16/fp16
        activation to compute the inner products for w1 and w2.
        Defaults to False.
    - global_num_experts (int): The total number of experts in the global
        expert space.
    - expert_map (Optional[torch.Tensor]): A tensor mapping expert indices
        from the global expert space to the local expert space of the expert
        parallel shard.
    - w1_scale (Optional[torch.Tensor]): Optional scale to be used for w1.
    - w2_scale (Optional[torch.Tensor]): Optional scale to be used for w2.
    - a1_scale (Optional[torch.Tensor]): Optional scale to be used for a1.
    - a2_scale (Optional[torch.Tensor]): Optional scale to be used for a2.
    - block_shape: (Optional[list[int]]): Optional block size for block-wise
        quantization.

    Returns:
    - torch.Tensor: The output tensor after applying the MoE layer.
    """

    if use_grouped_topk:
        assert num_expert_group is not None and topk_group is not None
        topk_weights, topk_ids = grouped_topk(hidden_states, gating_output,
                                              topk, renormalize,
                                              num_expert_group, topk_group)
    elif custom_routing_function is None:
        topk_weights, topk_ids, token_expert_indices = fused_topk(
            hidden_states, gating_output, topk, renormalize)
    else:
        topk_weights, topk_ids = custom_routing_function(
            hidden_states, gating_output, topk, renormalize)

    return fused_experts(hidden_states,
                         w1,
                         w2,
                         topk_weights,
                         topk_ids,
                         inplace=inplace,
                         activation=activation,
                         use_fp8_w8a8=use_fp8_w8a8,
                         use_int8_w8a8=use_int8_w8a8,
                         use_int8_w8a16=use_int8_w8a16,
                         use_int4_w4a16=use_int4_w4a16,
                         per_channel_quant=per_channel_quant,
                         global_num_experts=global_num_experts,
                         expert_map=expert_map,
                         w1_scale=w1_scale,
                         w2_scale=w2_scale,
                         w1_zp=w1_zp,
                         w2_zp=w2_zp,
                         a1_scale=a1_scale,
                         a2_scale=a2_scale,
                         block_shape=block_shape)
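

# Illustrative sketch (not part of the original source), with hypothetical
# sizes: for an unquantized MoE with E experts, hidden size H and intermediate
# size I, the expected shapes are
#     hidden_states: (num_tokens, H)
#     w1:            (E, 2 * I, H)   # gate and up projections stacked
#     w2:            (E, H, I)
#     gating_output: (num_tokens, E)
# and
#     out = fused_moe(hidden_states, w1, w2, gating_output, topk=2,
#                     renormalize=True)
# returns a tensor of shape (num_tokens, H).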


class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):

    def __init__(
        self,
        use_fp8_w8a8: bool = False,
        use_int8_w8a8: bool = False,
        use_int8_w8a16: bool = False,
        use_int4_w4a16: bool = False,
        per_act_token_quant: bool = False,
        block_shape: Optional[list[int]] = None,
    ):
        super().__init__(
            FusedMoEQuantConfig.make(
                use_fp8_w8a8=use_fp8_w8a8,
                use_int8_w8a8=use_int8_w8a8,
                use_int8_w8a16=use_int8_w8a16,
                use_int4_w4a16=use_int4_w4a16,
                per_act_token_quant=per_act_token_quant,
                block_shape=block_shape,
            ))

        self.use_fp8_w8a8 = use_fp8_w8a8
        self.use_int4_w4a16 = use_int4_w4a16
        self.use_int8_w8a8 = use_int8_w8a8
        self.use_int8_w8a16 = use_int8_w8a16

    @property
    def activation_formats(
        self
    ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
        return (mk.FusedMoEActivationFormat.Standard,
                mk.FusedMoEActivationFormat.Standard)

    def supports_chunking(self) -> bool:
        return True

    def supports_expert_map(self) -> bool:
        return True

    def workspace_shapes(
        self,
        a: torch.Tensor,
        aq: torch.Tensor,
        M: int,
        N: int,
        K: int,
        topk: int,
        global_num_experts: int,
        local_num_experts: int,
    ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
        workspace1 = (M, topk, max(N * 2, K))
        workspace2 = (M, topk, N)
        output = (M, topk, K)
        return (workspace1, workspace2, output, a.dtype)

    def apply(
        self,
        output: torch.Tensor,
        hidden_states: torch.Tensor,
        w1: torch.Tensor,
        w2: torch.Tensor,
        topk_ids: torch.Tensor,
        activation: str,
        global_num_experts: int,
        expert_map: Optional[torch.Tensor],
        w1_scale: Optional[torch.Tensor],
        w2_scale: Optional[torch.Tensor],
        w1_zp: Optional[torch.Tensor],
        w2_zp: Optional[torch.Tensor],
        a1q_scale: Optional[torch.Tensor],
        a2_scale: Optional[torch.Tensor],
        workspace13: torch.Tensor,
        workspace2: torch.Tensor,
        expert_num_tokens: Optional[torch.Tensor],
    ):
        # Check constraints.
        if self.use_int4_w4a16:
            assert hidden_states.size(-1) // 2 == w1.size(2), (
                "Hidden size mismatch")
        else:
            assert hidden_states.size(-1) == w1.size(2), \
                (f"Hidden size mismatch {hidden_states.size(-1)} "
                 f"!= {w1.size(2)}")

        assert hidden_states.is_contiguous(
        ), "Hidden_states must be contiguous"
        assert hidden_states.dim() == 2
        assert w1.stride(-1) == 1, "Stride of last dimension must be 1"
        assert w2.stride(-1) == 1, "Stride of last dimension must be 1"
        assert hidden_states.dtype in [
            torch.float32, torch.float16, torch.bfloat16, torch.float8_e4m3fn
        ]

        E, num_tokens, N, K, top_k_num = mk._moe_problem_size(
            hidden_states, w1, w2, topk_ids)

        if global_num_experts == -1:
            global_num_experts = E

        config_dtype = get_config_dtype_str(use_fp8_w8a8=self.use_fp8_w8a8,
                                            use_int8_w8a16=self.use_int8_w8a16,
                                            use_int4_w4a16=self.use_int4_w4a16,
                                            dtype=hidden_states.dtype)

        config = try_get_optimal_moe_config(
            w1.size(),
            w2.size(),
            top_k_num,
            config_dtype,
            num_tokens,
            block_shape=self.block_shape,
        )

        if hidden_states.dtype == torch.bfloat16:
            compute_type = tl.bfloat16
        elif hidden_states.dtype == torch.float16:
            compute_type = tl.float16
        elif hidden_states.dtype == torch.float32:
            compute_type = tl.float32
        elif hidden_states.dtype == torch.float8_e4m3fn:
            compute_type = tl.bfloat16
        else:
            raise ValueError(
                f"Unsupported compute_type: {hidden_states.dtype}")

        # We can reuse the memory between these because by the time we need
        # cache3, we're done with cache1
        intermediate_cache1 = _resize_cache(workspace13,
                                            (num_tokens, top_k_num, N))
        intermediate_cache2 = _resize_cache(workspace2,
                                            (num_tokens * top_k_num, N // 2))

        sorted_token_ids, expert_ids, num_tokens_post_padded = (
            moe_align_block_size(topk_ids, config['BLOCK_SIZE_M'],
                                 global_num_experts, expert_map))

        invoke_fused_moe_kernel(hidden_states,
                                w1,
                                intermediate_cache1,
                                a1q_scale,
                                w1_scale,
                                w1_zp,
                                None,
                                sorted_token_ids,
                                expert_ids,
                                num_tokens_post_padded,
                                False,
                                top_k_num,
                                config,
                                compute_type=compute_type,
                                use_fp8_w8a8=self.use_fp8_w8a8,
                                use_int8_w8a8=self.use_int8_w8a8,
                                use_int8_w8a16=self.use_int8_w8a16,
                                use_int4_w4a16=self.use_int4_w4a16,
                                per_channel_quant=self.per_act_token_quant,
                                block_shape=self.block_shape)

        self.activation(activation, intermediate_cache2,
                        intermediate_cache1.view(-1, N))

        a2q_scale: Optional[torch.Tensor] = None

        qintermediate_cache2, a2q_scale = moe_kernel_quantize_input(
            intermediate_cache2, a2_scale, self.quant_dtype,
            self.per_act_token_quant, self.block_shape)

        invoke_fused_moe_kernel(qintermediate_cache2,
                                w2,
                                output,
                                a2q_scale,
                                w2_scale,
                                w2_zp,
                                None,
                                sorted_token_ids,
                                expert_ids,
                                num_tokens_post_padded,
                                False,
                                1,
                                config,
                                compute_type=compute_type,
                                use_fp8_w8a8=self.use_fp8_w8a8,
                                use_int8_w8a8=self.use_int8_w8a8,
                                use_int8_w8a16=self.use_int8_w8a16,
                                use_int4_w4a16=self.use_int4_w4a16,
                                per_channel_quant=self.per_act_token_quant,
                                block_shape=self.block_shape)

def modular_triton_fused_moe(
    use_fp8_w8a8: bool,
    use_int8_w8a8: bool,
    use_int8_w8a16: bool,
    use_int4_w4a16: bool,
    per_act_token_quant: bool,
    block_shape: Optional[list[int]] = None,
) -> mk.FusedMoEModularKernel:
    return mk.FusedMoEModularKernel(
        MoEPrepareAndFinalizeNoEP(),
        TritonExperts(
            use_fp8_w8a8=use_fp8_w8a8,
            use_int8_w8a8=use_int8_w8a8,
            use_int8_w8a16=use_int8_w8a16,
            use_int4_w4a16=use_int4_w4a16,
            per_act_token_quant=per_act_token_quant,
            block_shape=block_shape,
        ),
    )
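

# Illustrative note (not part of the original source): the returned
# mk.FusedMoEModularKernel pairs MoEPrepareAndFinalizeNoEP (input
# quantization/preparation and finalization, no expert parallelism) with the
# TritonExperts implementation above. A caller might build it as, for example,
#     fused = modular_triton_fused_moe(use_fp8_w8a8=False, use_int8_w8a8=False,
#                                      use_int8_w8a16=False,
#                                      use_int4_w4a16=False,
#                                      per_act_token_quant=False)
# and then invoke the resulting kernel object with hidden states, expert
# weights and top-k routing results in place of the monolithic
# fused_experts() path; the exact call signature is defined by
# mk.FusedMoEModularKernel, not by this file.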