[Kernel] some optimizations for dense marlin and moe marlin (#16850)

Signed-off-by: Jinzhen Lin <linjinzhen@hotmail.com>
This commit is contained in:
Jinzhen Lin
2025-05-06 00:39:30 +08:00
committed by GitHub
parent f62cad6431
commit 1d0c9d6b2d
26 changed files with 3512 additions and 3268 deletions

View File

@ -301,8 +301,52 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
# are not supported by Machete yet.
cuda_archs_loose_intersection(MARLIN_ARCHS "8.0;8.6;8.7;8.9;9.0;10.0;10.1;12.0" "${CUDA_ARCHS}")
if (MARLIN_ARCHS)
#
# For the Marlin kernels we automatically generate sources for various
# preselected input type pairs and schedules.
# Generate sources:
set(MARLIN_GEN_SCRIPT
${CMAKE_CURRENT_SOURCE_DIR}/csrc/quantization/gptq_marlin/generate_kernels.py)
file(MD5 ${MARLIN_GEN_SCRIPT} MARLIN_GEN_SCRIPT_HASH)
message(STATUS "Marlin generation script hash: ${MARLIN_GEN_SCRIPT_HASH}")
message(STATUS "Last run Marlin generate script hash: $CACHE{MARLIN_GEN_SCRIPT_HASH}")
if (NOT DEFINED CACHE{MARLIN_GEN_SCRIPT_HASH}
OR NOT $CACHE{MARLIN_GEN_SCRIPT_HASH} STREQUAL ${MARLIN_GEN_SCRIPT_HASH})
execute_process(
COMMAND ${CMAKE_COMMAND} -E env
PYTHONPATH=$PYTHONPATH
${Python_EXECUTABLE} ${MARLIN_GEN_SCRIPT}
RESULT_VARIABLE marlin_generation_result
OUTPUT_VARIABLE marlin_generation_result
OUTPUT_FILE ${CMAKE_CURRENT_BINARY_DIR}/marlin_generation.log
ERROR_FILE ${CMAKE_CURRENT_BINARY_DIR}/marlin_generation.log
)
if (NOT marlin_generation_result EQUAL 0)
message(FATAL_ERROR "Marlin generation failed."
" Result: \"${marlin_generation_result}\""
"\nCheck the log for details: "
"${CMAKE_CURRENT_BINARY_DIR}/marlin_generation.log")
else()
set(MARLIN_GEN_SCRIPT_HASH ${MARLIN_GEN_SCRIPT_HASH}
CACHE STRING "Last run Marlin generate script hash" FORCE)
message(STATUS "Marlin generation completed successfully.")
endif()
else()
message(STATUS "Marlin generation script has not changed, skipping generation.")
endif()
file(GLOB MARLIN_TEMPLATE_KERNEL_SRC "csrc/quantization/gptq_marlin/kernel_*.cu")
set_gencode_flags_for_srcs(
SRCS "${MARLIN_TEMPLATE_KERNEL_SRC}"
CUDA_ARCHS "${MARLIN_ARCHS}")
list(APPEND VLLM_EXT_SRC ${MARLIN_TEMPLATE_KERNEL_SRC})
set(MARLIN_SRCS
"csrc/quantization/fp8/fp8_marlin.cu"
"csrc/quantization/marlin/dense/marlin_cuda_kernel.cu"
"csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu"
"csrc/quantization/marlin/qqq/marlin_qqq_gemm_kernel.cu"
@ -644,7 +688,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
OR NOT $CACHE{MOE_MARLIN_GEN_SCRIPT_HASH} STREQUAL ${MOE_MARLIN_GEN_SCRIPT_HASH})
execute_process(
COMMAND ${CMAKE_COMMAND} -E env
PYTHONPATH=${CMAKE_CURRENT_SOURCE_DIR}/csrc/cutlass_extensions/:${CUTLASS_DIR}/python/:${VLLM_PYTHON_PATH}:$PYTHONPATH
PYTHONPATH=$PYTHONPATH
${Python_EXECUTABLE} ${MOE_MARLIN_GEN_SCRIPT}
RESULT_VARIABLE moe_marlin_generation_result
OUTPUT_VARIABLE moe_marlin_generation_output

1
csrc/moe/marlin_moe_wna16/.gitignore vendored Normal file
View File

@ -0,0 +1 @@
kernel_*.cu

View File

@ -25,15 +25,13 @@ TEMPLATE = ("template __global__ void Marlin<"
"{{thread_k_blocks}}, "
"{{'true' if m_block_size_8 else 'false'}}, "
"{{stages}}, "
"{{'true' if has_act_order else 'false'}}, "
"{{'true' if has_zp else 'false'}}, "
"{{group_blocks}}, "
"{{'true' if is_zp_float else 'false'}}>"
"( MARLIN_KERNEL_PARAMS );")
# int8 with zero point case (vllm::kU8) is also supported,
# we don't add it to reduce wheel size.
SCALAR_TYPES = ["vllm::kU4", "vllm::kU4B8", "vllm::kU8B128"]
SCALAR_TYPES = ["vllm::kU4", "vllm::kU4B8", "vllm::kU8B128", "vllm::kFE4M3fn"]
THREAD_CONFIGS = [(128, 128, 256), (64, 256, 256), (64, 128, 128)]
THREAD_M_BLOCKS = [0.5, 1, 2, 3, 4]
@ -52,21 +50,29 @@ def remove_old_kernels():
def generate_new_kernels():
for scalar_type, dtype in itertools.product(SCALAR_TYPES, DTYPES):
has_zp = "B" not in scalar_type
all_template_str_list = []
for group_blocks, m_blocks, thread_configs in itertools.product(
GROUP_BLOCKS, THREAD_M_BLOCKS, THREAD_CONFIGS):
has_act_order = group_blocks == 0
if has_zp and has_act_order:
# act order case only support gptq-int4 and gptq-int8
if group_blocks == 0 and scalar_type not in [
"vllm::kU4B8", "vllm::kU8B128"
]:
continue
if thread_configs[2] == 256:
# for small batch (m_blocks == 1), we only need (128, 128, 256)
# for large batch (m_blocks > 1), we only need (64, 256, 256)
if m_blocks <= 1 and thread_configs[0] != 128:
continue
if m_blocks > 1 and thread_configs[0] != 64:
continue
# we only support channelwise quantization and group_size == 128
# for fp8
if scalar_type == "vllm::kFE4M3fn" and group_blocks not in [-1, 8]:
continue
k_blocks = thread_configs[0] // 16
n_blocks = thread_configs[1] // 16
threads = thread_configs[2]
@ -82,8 +88,6 @@ def generate_new_kernels():
thread_k_blocks=k_blocks,
m_block_size_8=m_blocks == 0.5,
stages="pipe_stages",
has_act_order=has_act_order,
has_zp=has_zp,
group_blocks=group_blocks,
is_zp_float=False,
)

View File

@ -18,7 +18,7 @@
const float *__restrict__ topk_weights_ptr, int top_k, \
bool mul_topk_weights, bool is_ep, int num_groups, int prob_m, \
int prob_n, int prob_k, int *locks, bool use_atomic_add, \
bool use_fp32_reduce
bool use_fp32_reduce, int max_shared_mem
namespace MARLIN_NAMESPACE_NAME {
template <typename scalar_t, // compute dtype, half or nv_float16
@ -33,11 +33,9 @@ template <typename scalar_t, // compute dtype, half or nv_float16
// only works when thread_m_blocks == 1
const int stages, // number of stages for the async global->shared
// fetch pipeline
const bool has_act_order, // whether act_order is enabled
const bool has_zp, // whether zero-points are enabled
const int group_blocks, // number of consecutive 16x16 blocks
// with a separate quantization scale
const bool is_zp_float // is zero point of float16 type?
const int group_blocks, // number of consecutive 16x16 blocks
// with a separate quantization scale
const bool is_zp_float // is zero point of float16 type?
>
__global__ void Marlin(MARLIN_KERNEL_PARAMS);

View File

@ -25,6 +25,7 @@
#include "quantization/gptq_marlin/marlin.cuh"
#include "quantization/gptq_marlin/marlin_dtypes.cuh"
#include "quantization/gptq_marlin/dequant.h"
#include "core/scalar_type.hpp"
#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \
@ -48,11 +49,9 @@ template <typename scalar_t, // compute dtype, half or nv_float16
// only works when thread_m_blocks == 1
const int stages, // number of stages for the async global->shared
// fetch pipeline
const bool has_act_order, // whether act_order is enabled
const bool has_zp, // whether zero-points are enabled
const int group_blocks, // number of consecutive 16x16 blocks
// with a separate quantization scale
const bool is_zp_float // is zero point of float16 type?
const int group_blocks, // number of consecutive 16x16 blocks
// with a separate quantization scale
const bool is_zp_float // is zero point of float16 type?
>
__global__ void Marlin(
const int4* __restrict__ A, // fp16 input matrix of shape mxk
@ -77,8 +76,8 @@ __global__ void Marlin(
int prob_k, // reduction dimension k
int* locks, // extra global storage for barrier synchronization
bool use_atomic_add, // whether to use atomic add to reduce
bool use_fp32_reduce // whether to use fp32 global reduce
) {}
bool use_fp32_reduce, // whether to use fp32 global reduce
int max_shared_mem) {}
} // namespace MARLIN_NAMESPACE_NAME
@ -166,144 +165,6 @@ __device__ inline void ldsm(typename ScalarType<scalar_t>::FragA& frag_a,
}
}
// Lookup-table based 3-input logical operation; explicitly used for
// dequantization as the compiler does not seem to automatically recognize it in
// all cases.
template <int lut>
__device__ inline int lop3(int a, int b, int c) {
int res;
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
: "=r"(res)
: "r"(a), "r"(b), "r"(c), "n"(lut));
return res;
}
// Constructs destination register by taking bytes from 2 sources (based on
// mask)
template <int start_byte, int mask>
__device__ inline uint32_t prmt(uint32_t a) {
uint32_t res;
asm volatile("prmt.b32 %0, %1, %2, %3;\n"
: "=r"(res)
: "r"(a), "n"(start_byte), "n"(mask));
return res;
}
template <typename scalar_t, int bit>
__device__ inline typename ScalarType<scalar_t>::FragB dequant(
int q, typename ScalarType<scalar_t>::FragB& frag_b);
//
// Efficiently dequantize 4bit values packed in an int32 value into a full
// B-fragment of 4 fp16 values. We mostly follow the strategy in the link below,
// with some small changes:
// - FP16:
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L215-L287
// - BF16:
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L327-L385
//
template <>
__device__ inline typename ScalarType<half>::FragB dequant<half, 4>(
int q, typename ScalarType<half>::FragB& frag_b) {
const int LO = 0x000f000f;
const int HI = 0x00f000f0;
const int EX = 0x64006400;
// Guarantee that the `(a & b) | c` operations are LOP3s.
int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX);
int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX);
// We want signed int4 outputs, hence we fuse the `-8` symmetric zero point
// directly into `SUB` and `ADD`.
const int SUB = 0x64086408;
const int MUL = 0x2c002c00;
const int ADD = 0xd480d480;
frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo),
*reinterpret_cast<const half2*>(&SUB));
frag_b[1] = __hfma2(*reinterpret_cast<half2*>(&hi),
*reinterpret_cast<const half2*>(&MUL),
*reinterpret_cast<const half2*>(&ADD));
return frag_b;
}
template <>
__device__ inline typename ScalarType<nv_bfloat16>::FragB
dequant<nv_bfloat16, 4>(int q,
typename ScalarType<nv_bfloat16>::FragB& frag_b) {
static constexpr uint32_t MASK = 0x000f000f;
static constexpr uint32_t EX = 0x43004300;
// Guarantee that the `(a & b) | c` operations are LOP3s.
int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
q >>= 4;
int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
static constexpr uint32_t MUL = 0x3F803F80;
static constexpr uint32_t ADD = 0xC308C308;
frag_b[0] = __hfma2(*reinterpret_cast<nv_bfloat162*>(&lo),
*reinterpret_cast<const nv_bfloat162*>(&MUL),
*reinterpret_cast<const nv_bfloat162*>(&ADD));
frag_b[1] = __hfma2(*reinterpret_cast<nv_bfloat162*>(&hi),
*reinterpret_cast<const nv_bfloat162*>(&MUL),
*reinterpret_cast<const nv_bfloat162*>(&ADD));
return frag_b;
}
//
// Fast Int8ToFp16/Int8ToBf16: Efficiently dequantize 8bit int values to fp16 or
// bf16 Reference:
// - FP16:
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L53-L85
// - BF16:
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L125-L175
//
template <>
__device__ inline typename ScalarType<half>::FragB dequant<half, 8>(
int q, typename ScalarType<half>::FragB& frag_b) {
static constexpr uint32_t mask_for_elt_01 = 0x5250;
static constexpr uint32_t mask_for_elt_23 = 0x5351;
static constexpr uint32_t start_byte_for_fp16 = 0x64646464;
uint32_t lo = prmt<start_byte_for_fp16, mask_for_elt_01>(q);
uint32_t hi = prmt<start_byte_for_fp16, mask_for_elt_23>(q);
static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480;
frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo),
*reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
frag_b[1] = __hsub2(*reinterpret_cast<half2*>(&hi),
*reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
return frag_b;
}
template <>
__device__ inline typename ScalarType<nv_bfloat16>::FragB
dequant<nv_bfloat16, 8>(int q,
typename ScalarType<nv_bfloat16>::FragB& frag_b) {
float fp32_intermediates[4];
uint32_t* fp32_intermediates_casted =
reinterpret_cast<uint32_t*>(fp32_intermediates);
static constexpr uint32_t fp32_base = 0x4B000000;
fp32_intermediates_casted[0] = __byte_perm(q, fp32_base, 0x7650);
fp32_intermediates_casted[1] = __byte_perm(q, fp32_base, 0x7652);
fp32_intermediates_casted[2] = __byte_perm(q, fp32_base, 0x7651);
fp32_intermediates_casted[3] = __byte_perm(q, fp32_base, 0x7653);
fp32_intermediates[0] -= 8388736.f;
fp32_intermediates[1] -= 8388736.f;
fp32_intermediates[2] -= 8388736.f;
fp32_intermediates[3] -= 8388736.f;
uint32_t* bf16_result_ptr = reinterpret_cast<uint32_t*>(&frag_b);
bf16_result_ptr[0] = __byte_perm(fp32_intermediates_casted[0],
fp32_intermediates_casted[1], 0x7632);
bf16_result_ptr[1] = __byte_perm(fp32_intermediates_casted[2],
fp32_intermediates_casted[3], 0x7632);
return frag_b;
}
// Multiply dequantized values by the corresponding quantization scale; used
// only for grouped quantization.
template <typename scalar_t>
@ -429,11 +290,9 @@ template <typename scalar_t, // compute dtype, half or nv_float16
// only works when thread_m_blocks == 1
const int stages, // number of stages for the async global->shared
// fetch pipeline
const bool has_act_order, // whether act_order is enabled
const bool has_zp, // whether zero-points are enabled
const int group_blocks, // number of consecutive 16x16 blocks
// with a separate quantization scale
const bool is_zp_float // is zero point of float16 type?
const int group_blocks, // number of consecutive 16x16 blocks
// with a separate quantization scale
const bool is_zp_float // is zero point of float16 type?
>
__global__ void Marlin(
const int4* __restrict__ A, // fp16 input matrix of shape mxk
@ -458,8 +317,8 @@ __global__ void Marlin(
int prob_k, // reduction dimension k
int* locks, // extra global storage for barrier synchronization
bool use_atomic_add, // whether to use atomic add to reduce
bool use_fp32_reduce // whether to use fp32 global reduce
) {
bool use_fp32_reduce, // whether to use fp32 global reduce
int max_shared_mem) {
// Each threadblock processes one "stripe" of the B matrix with (roughly) the
// same size, which might involve multiple column "slices" (of width 16 *
// `thread_n_blocks`). Stripes are defined as shown in the 3x3 matrix 5 SM
@ -481,6 +340,8 @@ __global__ void Marlin(
extern __shared__ int4 sh[];
static constexpr auto w_type = vllm::ScalarType::from_id(w_type_id);
constexpr bool has_zp = w_type == vllm::kU4 || w_type == vllm::kU8;
constexpr bool has_act_order = group_blocks == 0;
constexpr int pack_factor = 32 / w_type.size_bits();
static_assert(thread_m_blocks == 1 || !m_block_size_8);
@ -534,13 +395,20 @@ __global__ void Marlin(
int64_t B_expert_off = 0;
int4* sh_block_sorted_ids_int4 = sh;
int4* sh_rd_block_sorted_ids_int4 =
sh_block_sorted_ids_int4 + moe_block_size / 4;
int4* sh_block_topk_weights_int4 =
sh_rd_block_sorted_ids_int4 + moe_block_size / 4;
// sh_block_topk_weights_int4 only need (moe_block_size / 4);
// but we pad to align to 256 bytes
int4* sh_new =
sh_block_topk_weights_int4 + moe_block_size / 2 + moe_block_size;
int32_t* sh_block_sorted_ids =
reinterpret_cast<int*>(sh_block_sorted_ids_int4);
int4* sh_block_topk_weights_int4 =
sh_block_sorted_ids_int4 + moe_block_size / 4;
int32_t* sh_rd_block_sorted_ids =
reinterpret_cast<int*>(sh_rd_block_sorted_ids_int4);
scalar_t2* sh_block_topk_weights =
reinterpret_cast<scalar_t2*>(sh_block_topk_weights_int4);
int4* sh_new = sh_block_topk_weights_int4 + moe_block_size / 4;
int32_t block_num_valid_tokens = 0;
int32_t locks_off = 0;
@ -584,6 +452,11 @@ __global__ void Marlin(
sh_block_sorted_ids_int4[tid4] = reinterpret_cast<const int4*>(
sorted_token_ids_ptr)[block_id * moe_block_size / 4 + tid4];
#pragma unroll
for (int i = 0; i < 4; i++)
sh_rd_block_sorted_ids[tid4 * 4 + i] =
sh_block_sorted_ids[tid4 * 4 + i] / top_k;
if (mul_topk_weights) {
#pragma unroll
for (int i = 0; i < 4; i++) {
@ -743,6 +616,7 @@ __global__ void Marlin(
constexpr int g_idx_stage = has_act_order ? (tb_k * sizeof(int)) / 16 : 0;
// constexpr int act_s_row_stride = 1;
// int act_s_col_stride = act_s_row_stride * num_groups;
constexpr int act_s_max_num_groups = 32;
int act_s_col_stride = 1;
int act_s_col_warp_stride = act_s_col_stride * 8;
int tb_n_warps = thread_n_blocks / 4;
@ -758,9 +632,9 @@ __global__ void Marlin(
int zp_gl_rd_delta = zp_gl_stride;
// Global A read index of current thread.
int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) +
(threadIdx.x % a_gl_rd_delta_o);
a_gl_rd += a_gl_rd_delta_o * slice_row;
int a_gl_rd_row = threadIdx.x / a_gl_rd_delta_o;
int a_gl_rd_col = a_gl_rd_delta_o * slice_row + threadIdx.x % a_gl_rd_delta_o;
// Shared write index of current thread.
int a_sh_wr = a_sh_stride * (threadIdx.x / a_gl_rd_delta_o) +
(threadIdx.x % a_gl_rd_delta_o);
@ -774,8 +648,8 @@ __global__ void Marlin(
(threadIdx.x % b_sh_stride_threads) * b_thread_vecs;
b_gl_rd += b_sh_stride * slice_col;
b_gl_rd += b_gl_rd_delta_o * slice_row;
int b_sh_wr = threadIdx.x * b_thread_vecs;
int b_sh_rd = threadIdx.x * b_thread_vecs;
auto b_sh_wr = threadIdx.x * b_thread_vecs;
auto b_sh_rd = threadIdx.x * b_thread_vecs;
// For act_order
constexpr int k_iter_size = tb_k / b_sh_wr_iters;
@ -794,7 +668,7 @@ __global__ void Marlin(
s_sh_stride * slice_col + threadIdx.x;
}
}
int s_sh_wr = threadIdx.x;
auto s_sh_wr = threadIdx.x;
bool s_sh_wr_pred = threadIdx.x < s_sh_stride;
// Zero-points
@ -807,7 +681,7 @@ __global__ void Marlin(
zp_sh_stride * slice_col + threadIdx.x;
}
}
int zp_sh_wr = threadIdx.x;
auto zp_sh_wr = threadIdx.x;
bool zp_sh_wr_pred = threadIdx.x < zp_sh_stride;
// We use a different scale layout for grouped and column-wise quantization as
@ -851,7 +725,7 @@ __global__ void Marlin(
// each warp must also write a consecutive memory segment?
auto transform_a = [&](int i) {
int row = i / a_gl_rd_delta_o;
return a_gl_rd_delta_o * row + (i % a_gl_rd_delta_o) ^ row;
return a_gl_rd_delta_o * row + (i % a_gl_rd_delta_o) ^ (row % 8);
};
// Since the computation of this remapping is non-trivial and, due to our main
// loop unrolls, all shared memory accesses are static, we simply precompute
@ -879,12 +753,28 @@ __global__ void Marlin(
B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd;
// Shared memory storage for global fetch pipelines.
int4* sh_a = sh_new;
int4* sh_b = sh_a + (stages * a_sh_stage);
int4* sh_g_idx = sh_b + (stages * b_sh_stage);
constexpr int sh_red_size = (2 * thread_n_blocks + 1) * 16 * thread_m_blocks;
constexpr int sh_b_size = stages * b_sh_stage;
int4* sh_b = sh_new;
int4* sh_red = sh_new;
int4* sh_g_idx = sh_b + (sh_red_size > sh_b_size ? sh_red_size : sh_b_size);
int4* sh_zp = sh_g_idx + (stages * g_idx_stage);
constexpr int sh_s_size = has_act_order ? (act_s_max_num_groups * s_sh_stride)
: (stages * s_sh_stage);
int4* sh_s = sh_zp + (stages * zp_sh_stage);
int4* sh_red = sh_b;
// shared memory reused by reduction should be smaller than
// shared memory used by weight.
static_assert(thread_m_blocks * 16 * thread_n_blocks * 16 / 8 <=
stages * b_sh_stage);
int4* sh_a = sh_s + sh_s_size;
constexpr int shm_size_used =
moe_block_size + stages * (g_idx_stage + zp_sh_stage) + sh_s_size +
(sh_red_size > sh_b_size ? sh_red_size : sh_b_size);
// all remaining shared memory is used to cache A (input)
// sh_a_max_row is at least ` stages * 16 * thread_m_blocks `
int sh_a_max_row =
((max_shared_mem - 1024) / 16 - shm_size_used) / (thread_k_blocks * 2);
// Register storage for double buffer of shared memory reads.
FragA frag_a[2][thread_m_blocks];
@ -905,15 +795,14 @@ __global__ void Marlin(
int sh_first_group_id = -1;
int sh_num_groups = -1;
constexpr int sh_max_num_groups = 32;
auto fetch_act_order_scales_to_shared = [&](bool is_async, int first_group_id,
int last_group_id) {
sh_first_group_id = first_group_id;
sh_num_groups = last_group_id - first_group_id + 1;
if (sh_num_groups < sh_max_num_groups) {
sh_num_groups = sh_max_num_groups;
if (sh_num_groups < act_s_max_num_groups) {
sh_num_groups = act_s_max_num_groups;
}
if (sh_first_group_id + sh_num_groups > num_groups) {
@ -940,27 +829,31 @@ __global__ void Marlin(
}
}
};
// Asynchronously fetch the next A, B and s tile from global to the next
// shared memory pipeline location.
int a_remaining_load_count_in_slice = stages;
auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true) {
bool should_load_a = true;
int max_num_stage_groups =
((sh_a_max_row - moe_block_size) / moe_block_size + 1) / stages;
max_num_stage_groups = max(max_num_stage_groups, 1);
auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true,
int pipe_a = 0) {
if (pred) {
int4* sh_a_stage = sh_a + a_sh_stage * pipe;
if (prob_k > thread_k_blocks * 16 * stages || slice_col == 0 ||
a_remaining_load_count_in_slice > 0) {
a_remaining_load_count_in_slice--;
if (should_load_a) {
int4* sh_a_stage = sh_a + moe_block_size * a_sh_stride * pipe_a;
#pragma unroll
for (int i = 0; i < a_sh_wr_iters; i++) {
int a_idx = a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off;
int row = a_idx / a_gl_stride;
int row = a_gl_rd_delta_i / a_gl_stride * i + a_gl_rd_row;
int64_t sorted_row = 0;
if (!m_block_size_8 || row < 8)
sorted_row = sh_block_sorted_ids[row] / top_k;
int64_t true_idx = sorted_row * a_gl_stride + a_idx % a_gl_stride;
sorted_row = sh_rd_block_sorted_ids[row];
int64_t true_idx =
sorted_row * a_gl_stride + a_gl_rd_col + a_gl_rd_delta_o * a_off;
cp_async4_pred(&sh_a_stage[a_sh_wr_trans[i]], &A[true_idx],
row < block_num_valid_tokens);
}
}
int4* sh_b_stage = sh_b + b_sh_stage * pipe;
#pragma unroll
for (int i = 0; i < b_sh_wr_iters; i++) {
@ -1063,8 +956,8 @@ __global__ void Marlin(
// Load the next sub-tile from the current location in the shared memory pipe
// into the current register buffer.
auto fetch_to_registers = [&](int k, int pipe) {
int4* sh_a_stage = sh_a + a_sh_stage * pipe;
auto fetch_to_registers = [&](int k, int pipe, int pipe_a = 0) {
int4* sh_a_stage = sh_a + moe_block_size * a_sh_stride * pipe_a;
#pragma unroll
for (int i = 0; i < thread_m_blocks; i++)
ldsm<m_block_size_8 ? 2 : 4, scalar_t>(
@ -1109,12 +1002,17 @@ __global__ void Marlin(
}
} else if constexpr (group_blocks != -1) {
if constexpr (group_blocks >= thread_k_blocks) {
int4* sh_s_stage =
sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) *
(pipe / (group_blocks / thread_k_blocks)));
reinterpret_cast<int4*>(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd];
if (k % b_sh_wr_iters == 0) {
int4* sh_s_stage =
sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) *
(pipe / (group_blocks / thread_k_blocks)));
reinterpret_cast<int4*>(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd];
} else {
reinterpret_cast<int4*>(&frag_s[1])[0] =
reinterpret_cast<int4*>(&frag_s[0])[0];
}
} else {
int warp_id = threadIdx.x / 32;
auto warp_id = threadIdx.x / 32;
int n_warps = thread_n_blocks / 4;
int warp_row = warp_id / n_warps;
@ -1152,7 +1050,7 @@ __global__ void Marlin(
// Determine "position" inside the thread-block (based on warp and
// thread-id)
int warp_id = threadIdx.x / 32;
auto warp_id = threadIdx.x / 32;
int n_warps =
thread_n_blocks / 4; // Each warp processes 4 16-size tiles over N
@ -1161,7 +1059,7 @@ __global__ void Marlin(
cur_k += warp_row * 16;
int th_id = threadIdx.x % 32;
auto th_id = threadIdx.x % 32;
cur_k += (th_id % 4) * 2; // Due to tensor-core layout for fp16 B matrix
int s_col_shift =
@ -1222,15 +1120,18 @@ __global__ void Marlin(
}
} else if constexpr (group_blocks >= thread_k_blocks) {
int4* sh_zp_stage =
sh_zp + zp_sh_stage * ((group_blocks / thread_k_blocks) *
(pipe / (group_blocks / thread_k_blocks)));
for (int i = 0; i < num_ints_per_thread; i++) {
frag_qzp[k % 2][i] =
(reinterpret_cast<int*>(sh_zp_stage))[zp_sh_rd + i];
if (k % b_sh_wr_iters == 0) {
int4* sh_zp_stage =
sh_zp + zp_sh_stage * ((group_blocks / thread_k_blocks) *
(pipe / (group_blocks / thread_k_blocks)));
#pragma unroll
for (int i = 0; i < num_ints_per_thread; i++) {
frag_qzp[k % 2][i] =
(reinterpret_cast<int*>(sh_zp_stage))[zp_sh_rd + i];
}
}
} else {
int warp_id = threadIdx.x / 32;
auto warp_id = threadIdx.x / 32;
int n_warps = thread_n_blocks / 4;
int warp_row = warp_id / n_warps;
@ -1251,6 +1152,7 @@ __global__ void Marlin(
sh_zp_stage += cur_group_id * zp_sh_stride;
#pragma unroll
for (int i = 0; i < num_ints_per_thread; i++) {
frag_qzp[k % 2][i] =
(reinterpret_cast<int*>(sh_zp_stage))[zp_sh_rd + i];
@ -1263,12 +1165,16 @@ __global__ void Marlin(
if constexpr (group_blocks != -1) {
if constexpr (group_blocks >= thread_k_blocks) {
int4* sh_zp_stage =
sh_zp + zp_sh_stage * ((group_blocks / thread_k_blocks) *
(pipe / (group_blocks / thread_k_blocks)));
reinterpret_cast<int4*>(&frag_zpf[k % 2])[0] = sh_zp_stage[zp_sh_rd];
if (k % b_sh_wr_iters == 0) {
int4* sh_zp_stage =
sh_zp +
zp_sh_stage * ((group_blocks / thread_k_blocks) *
(pipe / (group_blocks / thread_k_blocks)));
reinterpret_cast<int4*>(&frag_zpf[k % 2])[0] =
sh_zp_stage[zp_sh_rd];
}
} else {
int warp_id = threadIdx.x / 32;
auto warp_id = threadIdx.x / 32;
int n_warps = thread_n_blocks / 4;
int warp_row = warp_id / n_warps;
@ -1292,6 +1198,25 @@ __global__ void Marlin(
}
};
auto dequant_data = [&](int q, scalar_t2* frag_b_ptr) {
if constexpr (has_zp && is_zp_float || !has_zp) {
dequant<scalar_t2, w_type_id>(q, frag_b_ptr);
} else {
static_assert(has_zp && !is_zp_float);
static_assert(w_type_id == vllm::kU4.id() || w_type_id == vllm::kU8.id());
// If (has_zp && !is_zp_float),
// we use not-zp version `dequant` function
// to improve numerical accuracy.
// Since both weight and zero point are dequanted using this logic,
// the final dequanted weight would be correct.
if constexpr (w_type_id == vllm::kU4.id()) {
dequant<scalar_t2, vllm::kU4B8.id()>(q, frag_b_ptr);
} else if constexpr (w_type_id == vllm::kU8.id()) {
dequant<scalar_t2, vllm::kU8B128.id()>(q, frag_b_ptr);
}
}
};
// Execute the actual tensor core matmul of a sub-tile.
bool is_first_matmul_in_slice = true;
auto matmul = [&](int k) {
@ -1315,15 +1240,17 @@ __global__ void Marlin(
zp_quant_1 = frag_qzp[k2][1];
}
dequant<scalar_t, w_type.size_bits()>(zp_quant_0, frag_zp_0);
dequant<scalar_t, w_type.size_bits()>(zp_quant_1, frag_zp_1);
frag_zp[0] = frag_zp_0[0];
frag_zp[1] = frag_zp_0[1];
frag_zp[2] = frag_zp_1[0];
frag_zp[3] = frag_zp_1[1];
dequant_data(zp_quant_0, reinterpret_cast<scalar_t2*>(&frag_zp));
dequant_data(zp_quant_1, reinterpret_cast<scalar_t2*>(&frag_zp) + 2);
}
}
if constexpr (has_zp && is_zp_float) {
if (is_new_zp) {
reinterpret_cast<int4*>(&frag_zp)[0] =
reinterpret_cast<int4*>(&frag_zpf[k2])[0];
}
}
// We have the m dimension as the inner loop in order to encourage overlapping
// dequantization and matmul operations.
#pragma unroll
@ -1342,8 +1269,8 @@ __global__ void Marlin(
b_quant_1 = frag_b_quant_ptr[j * 2 + 1];
}
dequant<scalar_t, w_type.size_bits()>(b_quant_0, frag_b0);
dequant<scalar_t, w_type.size_bits()>(b_quant_1, frag_b1);
dequant_data(b_quant_0, reinterpret_cast<scalar_t2*>(&frag_b0));
dequant_data(b_quant_1, reinterpret_cast<scalar_t2*>(&frag_b1));
// Apply scale to frag_b0
if constexpr (has_act_order) {
@ -1351,8 +1278,7 @@ __global__ void Marlin(
scale4<scalar_t>(frag_b0, act_frag_s[k2][0][j], act_frag_s[k2][1][j],
act_frag_s[k2][2][j], act_frag_s[k2][3][j], 0);
scale4<scalar_t>(frag_b1, act_frag_s[k2][0][j], act_frag_s[k2][1][j],
act_frag_s[k][2][j], act_frag_s[k2][3][j], 1);
act_frag_s[k2][2][j], act_frag_s[k2][3][j], 1);
} else if constexpr (has_zp && !is_zp_float && group_blocks == -1) {
int idx = (threadIdx.x / 4) % 2;
scalar_t2 s2 = Dtype::nums2num2(
@ -1361,18 +1287,12 @@ __global__ void Marlin(
if (is_new_zp) frag_zp[j] = __hmul2(frag_zp[j], s2);
scale_and_sub<scalar_t>(frag_b0, s2.x, frag_zp[j].x);
scale_and_sub<scalar_t>(frag_b1, s2.y, frag_zp[j].y);
} else if constexpr (has_zp && !is_zp_float && group_blocks != -1) {
} else if constexpr (has_zp && group_blocks != -1) {
if (is_new_zp)
frag_zp[j] = __hmul2(frag_zp[j],
*reinterpret_cast<scalar_t2*>(&frag_s[k2][j]));
scale_and_sub<scalar_t>(frag_b0, frag_s[k % 2][j][0].x, frag_zp[j].x);
scale_and_sub<scalar_t>(frag_b1, frag_s[k % 2][j][0].y, frag_zp[j].y);
} else if constexpr (has_zp && is_zp_float && group_blocks != -1) {
if (is_new_zp)
frag_zpf[k2][j] = __hmul2(
frag_zpf[k2][j], *reinterpret_cast<scalar_t2*>(&frag_s[k2][j]));
scale_and_sub<scalar_t>(frag_b0, frag_s[k2][j].x, frag_zpf[k2][j].x);
scale_and_sub<scalar_t>(frag_b1, frag_s[k2][j].y, frag_zpf[k2][j].y);
scale_and_sub<scalar_t>(frag_b0, frag_s[k2][j][0].x, frag_zp[j].x);
scale_and_sub<scalar_t>(frag_b1, frag_s[k2][j][0].y, frag_zp[j].y);
} else if constexpr (group_blocks != -1) {
scale<scalar_t>(frag_b0, frag_s[k2][j], 0);
scale<scalar_t>(frag_b1, frag_s[k2][j], 1);
@ -1397,7 +1317,7 @@ __global__ void Marlin(
auto thread_block_reduce = [&]() {
constexpr int red_off = threads / b_sh_stride_threads / 2;
if (red_off >= 1) {
int red_idx = threadIdx.x / b_sh_stride_threads;
auto red_idx = threadIdx.x / b_sh_stride_threads;
constexpr int red_sh_stride = b_sh_stride_threads * 4 * 2;
constexpr int red_sh_delta = b_sh_stride_threads;
int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride_threads) +
@ -1731,7 +1651,7 @@ __global__ void Marlin(
fetch_col_scale_to_shared();
}
}
fetch_to_shared(i, i, i < slice_iters);
fetch_to_shared(i, i, i < slice_iters, i);
}
zero_accums();
@ -1740,8 +1660,10 @@ __global__ void Marlin(
fetch_to_registers(0, 0);
fetch_scales_to_registers(0, 0);
fetch_zp_to_registers(0, 0);
a_gl_rd += a_gl_rd_delta_o * (stages - 1);
slice_k_start_shared_fetch += tb_k * (stages - 1);
a_gl_rd_col += a_gl_rd_delta_o * (stages - 1);
if constexpr (has_act_order) {
slice_k_start_shared_fetch += tb_k * (stages - 1);
}
};
if (slice_iters) {
start_pipes();
@ -1754,45 +1676,58 @@ __global__ void Marlin(
// have even length meaning that the next iteration will always start at
// index 0.
for (int stage_group_id = 0; stage_group_id < max_num_stage_groups;
stage_group_id++) {
#pragma unroll
for (int pipe = 0; pipe < stages;) {
for (int pipe = 0; pipe < stages;) {
#pragma unroll
for (int k = 0; k < b_sh_wr_iters; k++) {
fetch_to_registers(k + 1, pipe % stages);
fetch_scales_to_registers(k + 1, pipe);
fetch_zp_to_registers(k + 1, pipe);
if (k == b_sh_wr_iters - 2) {
fetch_to_shared((pipe + stages - 1) % stages, pipe,
slice_iters >= stages);
pipe++;
wait_for_stage();
init_same_group(pipe % stages);
for (int k = 0; k < b_sh_wr_iters; k++) {
int idx =
(pipe >= stages && stage_group_id == max_num_stage_groups - 1)
? (pipe - stages)
: (pipe + stage_group_id * stages);
fetch_to_registers(k + 1, pipe % stages, idx);
fetch_scales_to_registers(k + 1, pipe);
fetch_zp_to_registers(k + 1, pipe);
if (k == b_sh_wr_iters - 2) {
int idx = (pipe >= 1 && stage_group_id == max_num_stage_groups - 1)
? (pipe - 1)
: (pipe + (stage_group_id + 1) * stages - 1);
fetch_to_shared((pipe + stages - 1) % stages, pipe,
slice_iters >= stages, idx);
pipe++;
wait_for_stage();
init_same_group(pipe % stages);
}
matmul(k);
}
slice_iters--;
if (slice_iters == 0) {
break;
}
}
a_gl_rd_col += a_gl_rd_delta_o * stages;
if constexpr (has_act_order) {
slice_k_start += tb_k * stages;
slice_k_start_shared_fetch += tb_k * stages;
int first_group_id = g_idx[slice_k_start];
int last_g_idx = slice_k_start + stages * tb_k * 2;
if (last_g_idx >= prob_k) {
last_g_idx = prob_k - 1;
}
int last_group_id = g_idx[last_g_idx];
if (last_group_id >= sh_first_group_id + sh_num_groups) {
fetch_act_order_scales_to_shared(false, first_group_id,
last_group_id);
__syncthreads();
}
matmul(k);
}
slice_iters--;
if (slice_iters == 0) {
break;
}
}
a_remaining_load_count_in_slice = 0;
a_gl_rd += a_gl_rd_delta_o * stages;
slice_k_start += tb_k * stages;
slice_k_start_shared_fetch += tb_k * stages;
if constexpr (has_act_order) {
int first_group_id = g_idx[slice_k_start];
int last_g_idx = slice_k_start + stages * tb_k * 2;
if (last_g_idx >= prob_k) {
last_g_idx = prob_k - 1;
}
int last_group_id = g_idx[last_g_idx];
if (last_group_id >= sh_first_group_id + sh_num_groups) {
fetch_act_order_scales_to_shared(false, first_group_id, last_group_id);
__syncthreads();
}
}
// Process results and, if necessary, proceed to the next column slice.
// While this pattern may not be the most readable, other ways of writing
@ -1877,15 +1812,30 @@ __global__ void Marlin(
if (last || use_atomic_add)
// only the last block in a slice actually writes the result
write_result();
if (slice_row) a_remaining_load_count_in_slice = stages;
int old_slice_row = slice_row;
slice_row = 0;
slice_col_par++;
slice_col++;
is_first_matmul_in_slice = true;
init_slice();
// Should we load A matrix in next slice?
// `slice_col == 0`: when move to a new moe block
// `old_slice_row > 0`:
// when the last slice is not starting from k_index == 0
// (only happen when it is the first slice of a threadblock)
// `prob_k > thread_k_blocks * 16 * stages * max_num_stage_groups`:
// when the required shared memory size is larger than
// the remaining shared memory
if (slice_col == 0 || old_slice_row ||
prob_k > thread_k_blocks * 16 * stages * max_num_stage_groups) {
should_load_a = true;
} else {
should_load_a = false;
}
if (slice_iters) {
a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) +
(threadIdx.x % a_gl_rd_delta_o);
a_gl_rd_col = (threadIdx.x % a_gl_rd_delta_o);
#pragma unroll
for (int i = 0; i < b_sh_wr_iters; i++)
B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles;
@ -1900,12 +1850,10 @@ __global__ void Marlin(
slice_k_finish = slice_k_start + tb_k * slice_iters;
slice_k_start_shared_fetch = slice_k_start;
slice_n_offset = act_s_col_tb_stride * slice_col;
} else {
s_gl_rd = s_sh_stride * slice_col + threadIdx.x;
zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x;
}
start_pipes();
}
}

View File

@ -116,7 +116,7 @@ __global__ void permute_cols_kernel(
int base_k = 0;
for (int i = 0; i < iters; i++) {
int cur_k = base_k + threadIdx.x;
auto cur_k = base_k + threadIdx.x;
int src_pos = perm_int_ptr[cur_k];
out_half[cur_k] = a_row_half[src_pos];
@ -126,7 +126,7 @@ __global__ void permute_cols_kernel(
if (rest) {
if (threadIdx.x < rest) {
int cur_k = base_k + threadIdx.x;
auto cur_k = base_k + threadIdx.x;
int src_pos = perm_int_ptr[cur_k];
out_half[cur_k] = a_row_half[src_pos];
@ -195,7 +195,6 @@ int get_scales_cache_size(thread_config_t const& th_config, int prob_m,
tb_groups * pipe_stages * 2; // Chunk size is 2x pipeline over dim K
load_groups = max(load_groups, 32); // We load at least 32 scale groups
return load_groups * tb_n * 2;
} else {
int tb_scales = tb_groups * tb_n * 2;
@ -203,22 +202,24 @@ int get_scales_cache_size(thread_config_t const& th_config, int prob_m,
}
}
int get_kernel_cache_size(thread_config_t const& th_config, int thread_m_blocks,
int prob_m, int prob_n, int prob_k, int num_bits,
int group_size, bool has_act_order, bool is_k_full,
int has_zp, int is_zp_float) {
int get_kernel_cache_size(thread_config_t const& th_config, bool m_block_size_8,
int thread_m_blocks, int prob_m, int prob_n,
int prob_k, int num_bits, int group_size,
bool has_act_order, bool is_k_full, int has_zp,
int is_zp_float) {
int pack_factor = 32 / num_bits;
// Get B size
int tb_k = th_config.thread_k;
int tb_n = th_config.thread_n;
int tb_m = thread_m_blocks * 16;
int tb_m = thread_m_blocks * (m_block_size_8 ? 8 : 16);
// shm size for block_sorted_ids/block_topk_weights
// shm size for block_sorted_ids/rd_block_sorted_ids/block_topk_weights
// both of them requires tb_m * 4 bytes (tb_m * int32 or tb_m * float32)
int sh_block_meta_size = tb_m * 4 * 2;
int sh_block_meta_size = tb_m * 4;
int sh_a_size = pipe_stages * (tb_m * tb_k) * 2;
int sh_b_size = pipe_stages * (tb_k * tb_n / pack_factor) * 4;
int sh_red_size = tb_m * (tb_n + 8) * 2;
int sh_s_size =
get_scales_cache_size(th_config, prob_m, prob_n, prob_k, num_bits,
group_size, has_act_order, is_k_full);
@ -233,16 +234,17 @@ int get_kernel_cache_size(thread_config_t const& th_config, int thread_m_blocks,
sh_zp_size = sh_s_size / 2;
}
int total_size = sh_a_size + sh_b_size + sh_s_size + sh_zp_size +
sh_g_idx_size + sh_block_meta_size;
int total_size = max(sh_b_size, sh_red_size) + sh_a_size + sh_s_size +
sh_zp_size + sh_g_idx_size + sh_block_meta_size;
return total_size;
}
bool is_valid_config(thread_config_t const& th_config, int thread_m_blocks,
int prob_m, int prob_n, int prob_k, int num_bits,
int group_size, bool has_act_order, bool is_k_full,
int has_zp, int is_zp_float, int max_shared_mem) {
bool is_valid_config(thread_config_t const& th_config, bool m_block_size_8,
int thread_m_blocks, int prob_m, int prob_n, int prob_k,
int num_bits, int group_size, bool has_act_order,
bool is_k_full, int has_zp, int is_zp_float,
int max_shared_mem) {
// Sanity
if (th_config.thread_k == -1 || th_config.thread_n == -1 ||
th_config.num_threads == -1) {
@ -266,143 +268,113 @@ bool is_valid_config(thread_config_t const& th_config, int thread_m_blocks,
// Check that pipeline fits into cache
int cache_size = get_kernel_cache_size(
th_config, thread_m_blocks, prob_m, prob_n, prob_k, num_bits, group_size,
has_act_order, is_k_full, has_zp, is_zp_float);
th_config, m_block_size_8, thread_m_blocks, prob_m, prob_n, prob_k,
num_bits, group_size, has_act_order, is_k_full, has_zp, is_zp_float);
return cache_size <= max_shared_mem;
}
#define __GET_IF(W_TYPE, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \
M_BLOCK_SIZE_8, HAS_ACT_ORDER, HAS_ZP, GROUP_BLOCKS, \
NUM_THREADS, IS_ZP_FLOAT) \
else if (q_type == W_TYPE && thread_m_blocks == THREAD_M_BLOCKS && \
thread_n_blocks == THREAD_N_BLOCKS && \
thread_k_blocks == THREAD_K_BLOCKS && \
m_block_size_8 == M_BLOCK_SIZE_8 && \
has_act_order == HAS_ACT_ORDER && has_zp == HAS_ZP && \
group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS && \
is_zp_float == IS_ZP_FLOAT) { \
kernel = Marlin<scalar_t, W_TYPE.id(), NUM_THREADS, THREAD_M_BLOCKS, \
THREAD_N_BLOCKS, THREAD_K_BLOCKS, M_BLOCK_SIZE_8, \
pipe_stages, HAS_ACT_ORDER, HAS_ZP, GROUP_BLOCKS, \
IS_ZP_FLOAT>; \
#define _GET_IF(W_TYPE, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \
M_BLOCK_SIZE_8, GROUP_BLOCKS, NUM_THREADS, IS_ZP_FLOAT) \
else if (q_type == W_TYPE && thread_m_blocks == THREAD_M_BLOCKS && \
thread_n_blocks == THREAD_N_BLOCKS && \
thread_k_blocks == THREAD_K_BLOCKS && \
m_block_size_8 == M_BLOCK_SIZE_8 && \
group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS && \
is_zp_float == IS_ZP_FLOAT) { \
kernel = Marlin<scalar_t, W_TYPE.id(), NUM_THREADS, THREAD_M_BLOCKS, \
THREAD_N_BLOCKS, THREAD_K_BLOCKS, M_BLOCK_SIZE_8, \
pipe_stages, GROUP_BLOCKS, IS_ZP_FLOAT>; \
}
#define GPTQ_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, true, false, 0, NUM_THREADS, \
false) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, false, 0, \
NUM_THREADS, false) \
\
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, false, -1, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, false, 2, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, false, 4, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, false, 8, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, false, -1, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, false, 2, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, false, 4, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, false, 8, \
NUM_THREADS, false)
// COMMON: cases for (group_blocks in [-1, 2, 4, 8] and is_zp_float == false)
// this is the most common cases
// BIGGROUP: cases for big group size (group_blocks in [-1, 8])
// FZP: cases for float-zero-point (is_zp_float = true)
// ACT: cases for act order case (group_blocks == 0)
#define COMMON_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, -1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 2, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 4, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 8, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false)
#define GPTQ_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
__GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, false, 0, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, false, 0, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, false, 0, \
NUM_THREADS, false) \
\
__GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, false, -1, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, false, 2, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, false, 4, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, false, 8, \
NUM_THREADS, false) \
\
__GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, false, -1, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, false, 2, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, false, 4, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, false, 8, \
NUM_THREADS, false) \
\
__GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, false, -1, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, false, 2, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, false, 4, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, false, 8, \
NUM_THREADS, false)
#define COMMON_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \
\
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \
\
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false)
#define AWQ_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, true, -1, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, true, 2, NUM_THREADS, \
false) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, true, 4, NUM_THREADS, \
false) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, true, 8, NUM_THREADS, \
false) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, true, -1, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, true, 2, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, true, 4, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, true, 8, \
NUM_THREADS, false)
#define COMMON_GET_IF(W_TYPE) \
COMMON_GET_IF_M1(W_TYPE, 8, 8, 256) \
COMMON_GET_IF_M1(W_TYPE, 8, 4, 128) \
COMMON_GET_IF_M234(W_TYPE, 16, 4, 256) \
COMMON_GET_IF_M234(W_TYPE, 8, 4, 128)
#define AWQ_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
__GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, true, -1, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, true, 2, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, true, 4, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, true, 8, \
NUM_THREADS, false) \
\
__GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, true, -1, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, true, 2, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, true, 4, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, true, 8, \
NUM_THREADS, false) \
\
__GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, true, -1, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, true, 2, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, true, 4, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, true, 8, \
NUM_THREADS, false)
#define BIGGROUP_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, -1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 8, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false)
#define BIGGROUP_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \
\
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false)
#define BIGGROUP_GET_IF(W_TYPE) \
BIGGROUP_GET_IF_M1(W_TYPE, 8, 8, 256) \
BIGGROUP_GET_IF_M1(W_TYPE, 8, 4, 128) \
BIGGROUP_GET_IF_M234(W_TYPE, 16, 4, 256) \
BIGGROUP_GET_IF_M234(W_TYPE, 8, 4, 128)
// We currently have 4-bit models only with group_blocks == 4
#define HQQ_GET_IF(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, true, 4, NUM_THREADS, \
true) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, true, 4, \
NUM_THREADS, true) \
__GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, true, 4, \
NUM_THREADS, true) \
__GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, true, 4, \
NUM_THREADS, true) \
__GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, true, 4, \
NUM_THREADS, true)
#define FZP_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 4, NUM_THREADS, true) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true)
#define FZP_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true) \
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true) \
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true)
#define FZP_GET_IF(W_TYPE) \
FZP_GET_IF_M1(W_TYPE, 8, 8, 256) \
FZP_GET_IF_M1(W_TYPE, 8, 4, 128) \
FZP_GET_IF_M234(W_TYPE, 16, 4, 256) \
FZP_GET_IF_M234(W_TYPE, 8, 4, 128)
// We currently have 4-bit models only with group_blocks == 4
#define ACT_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false)
#define ACT_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false) \
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false) \
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false)
#define ACT_GET_IF(W_TYPE) \
ACT_GET_IF_M1(W_TYPE, 8, 8, 256) \
ACT_GET_IF_M1(W_TYPE, 8, 4, 128) \
ACT_GET_IF_M234(W_TYPE, 16, 4, 256) \
ACT_GET_IF_M234(W_TYPE, 8, 4, 128)
template <typename scalar_t>
MarlinFuncPtr get_marlin_kernel(const vllm::ScalarType q_type,
@ -415,23 +387,15 @@ MarlinFuncPtr get_marlin_kernel(const vllm::ScalarType q_type,
auto kernel = MarlinDefault;
if (false) {
}
GPTQ_GET_IF_M1(vllm::kU4B8, 8, 8, 256)
GPTQ_GET_IF_M1(vllm::kU4B8, 8, 4, 128)
GPTQ_GET_IF_M234(vllm::kU4B8, 16, 4, 256)
GPTQ_GET_IF_M234(vllm::kU4B8, 8, 4, 128)
COMMON_GET_IF(vllm::kU4)
COMMON_GET_IF(vllm::kU4B8)
COMMON_GET_IF(vllm::kU8B128)
GPTQ_GET_IF_M1(vllm::kU8B128, 8, 8, 256)
GPTQ_GET_IF_M1(vllm::kU8B128, 8, 4, 128)
BIGGROUP_GET_IF(vllm::kFE4M3fn)
GPTQ_GET_IF_M234(vllm::kU8B128, 16, 4, 256)
GPTQ_GET_IF_M234(vllm::kU8B128, 8, 4, 128)
AWQ_GET_IF_M1(vllm::kU4, 8, 8, 256)
AWQ_GET_IF_M1(vllm::kU4, 8, 4, 128)
AWQ_GET_IF_M234(vllm::kU4, 16, 4, 256)
AWQ_GET_IF_M234(vllm::kU4, 8, 4, 128)
ACT_GET_IF(vllm::kU4B8)
ACT_GET_IF(vllm::kU8B128)
return kernel;
}
@ -457,19 +421,19 @@ exec_config_t determine_exec_config(const vllm::ScalarType& q_type, int prob_m,
for (int i = 0; i < thread_configs_size; i++) {
thread_config_t th_config = thread_configs[i];
if (!is_valid_config(th_config, thread_m_blocks, prob_m, prob_n, prob_k,
num_bits, group_size, has_act_order, is_k_full, has_zp,
is_zp_float, max_shared_mem)) {
if (!is_valid_config(th_config, m_block_size_8, thread_m_blocks, prob_m,
prob_n, prob_k, num_bits, group_size, has_act_order,
is_k_full, has_zp, is_zp_float, max_shared_mem)) {
continue;
}
int cache_size = get_kernel_cache_size(
th_config, thread_m_blocks, prob_m, prob_n, prob_k, num_bits,
group_size, has_act_order, is_k_full, has_zp, is_zp_float);
th_config, m_block_size_8, thread_m_blocks, prob_m, prob_n, prob_k,
num_bits, group_size, has_act_order, is_k_full, has_zp, is_zp_float);
int group_blocks = 0;
if (!has_act_order) {
group_blocks = group_size == -1 ? -1 : group_size / 16;
group_blocks = group_size == -1 ? -1 : (group_size / 16);
}
auto kernel = get_marlin_kernel<scalar_t>(
@ -515,14 +479,14 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
bool m_block_size_8 = moe_block_size == 8;
if (has_zp) {
TORCH_CHECK(
q_type == vllm::kU4 || q_type == vllm::kU8,
"q_type must be u4 or u8 when has_zp = True. Got = ", q_type.str());
TORCH_CHECK(q_type == vllm::kU4,
"q_type must be u4 when has_zp = True. Got = ", q_type.str());
} else {
TORCH_CHECK(
q_type == vllm::kU4B8 || q_type == vllm::kU8B128,
"q_type must be uint4b8 or uint8b128 when has_zp = False. Got = ",
q_type.str());
TORCH_CHECK(q_type == vllm::kU4B8 || q_type == vllm::kU8B128 ||
q_type == vllm::kFE4M3fn,
"q_type must be uint4b8, uint8b128 or fp8e4m3 when has_zp = "
"False. Got = ",
q_type.str());
}
TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m,
@ -631,18 +595,18 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
int thread_k_blocks = thread_k / 16;
int thread_n_blocks = thread_n / 16;
TORCH_CHECK(is_valid_config(thread_tfg, thread_m_blocks, prob_m, prob_n,
prob_k, num_bits, group_size, has_act_order,
is_k_full, has_zp, is_zp_float, max_shared_mem),
"Invalid thread config: thread_m_blocks = ", thread_m_blocks,
", thread_k = ", thread_tfg.thread_k,
", thread_n = ", thread_tfg.thread_n,
", num_threads = ", thread_tfg.num_threads, " for MKN = [",
prob_m, ", ", prob_k, ", ", prob_n, "] and num_bits = ", num_bits,
", group_size = ", group_size,
", has_act_order = ", has_act_order, ", is_k_full = ", is_k_full,
", has_zp = ", has_zp, ", is_zp_float = ", is_zp_float,
", max_shared_mem = ", max_shared_mem);
TORCH_CHECK(
is_valid_config(thread_tfg, m_block_size_8, thread_m_blocks, prob_m,
prob_n, prob_k, num_bits, group_size, has_act_order,
is_k_full, has_zp, is_zp_float, max_shared_mem),
"Invalid thread config: thread_m_blocks = ", thread_m_blocks,
", thread_k = ", thread_tfg.thread_k,
", thread_n = ", thread_tfg.thread_n,
", num_threads = ", thread_tfg.num_threads, " for MKN = [", prob_m, ", ",
prob_k, ", ", prob_n, "] and num_bits = ", num_bits,
", group_size = ", group_size, ", has_act_order = ", has_act_order,
", is_k_full = ", is_k_full, ", has_zp = ", has_zp,
", is_zp_float = ", is_zp_float, ", max_shared_mem = ", max_shared_mem);
auto kernel = get_marlin_kernel<scalar_t>(
q_type, thread_m_blocks, thread_n_blocks, thread_k_blocks, m_block_size_8,
@ -666,7 +630,7 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
A_ptr, B_ptr, C_ptr, C_tmp_ptr, s_ptr, zp_ptr, g_idx_ptr,
sorted_token_ids_ptr, expert_ids_ptr, num_tokens_past_padded_ptr,
topk_weights_ptr, top_k, mul_topk_weights, is_ep, num_groups, prob_m,
prob_n, prob_k, locks, use_atomic_add, use_fp32_reduce);
prob_n, prob_k, locks, use_atomic_add, use_fp32_reduce, max_shared_mem);
// clang-format on
}
@ -841,10 +805,11 @@ torch::Tensor moe_wna16_marlin_gemm(
b_q_type == vllm::kU4,
"b_q_type must be u4 when has_zp = True. Got = ", b_q_type.str());
} else {
TORCH_CHECK(
b_q_type == vllm::kU4B8 || b_q_type == vllm::kU8B128,
"b_q_type must be uint4b8 or uint8b128 when has_zp = False. Got = ",
b_q_type.str());
TORCH_CHECK(b_q_type == vllm::kU4B8 || b_q_type == vllm::kU8B128 ||
b_q_type == vllm::kFE4M3fn,
"b_q_type must be uint4b8, uint8b128 or fp8e4m3 when has_zp = "
"False. Got = ",
b_q_type.str());
}
if (has_zp && is_zp_float) {

View File

@ -0,0 +1 @@
kernel_*.cu

View File

@ -0,0 +1,291 @@
#include "marlin_dtypes.cuh"
namespace MARLIN_NAMESPACE_NAME {
#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800
// Lookup-table based 3-input logical operation; explicitly used for
// dequantization as the compiler does not seem to automatically recognize it in
// all cases.
template <int lut>
__device__ inline int lop3(int a, int b, int c) {
int res;
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
: "=r"(res)
: "r"(a), "r"(b), "r"(c), "n"(lut));
return res;
}
// Constructs destination register by taking bytes from 2 sources (based on
// mask)
template <int start_byte, int mask>
__device__ inline uint32_t prmt(uint32_t a) {
uint32_t res;
asm volatile("prmt.b32 %0, %1, %2, %3;\n"
: "=r"(res)
: "r"(a), "n"(start_byte), "n"(mask));
return res;
}
template <typename scalar_t2, vllm::ScalarTypeId w_type_id>
__device__ inline void dequant(int q, scalar_t2* frag_b);
//
// Efficiently dequantize 4bit values packed in an int32 value into a full
// B-fragment of 4 fp16 values. We mostly follow the strategy in the link below,
// with some small changes:
// - FP16:
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L215-L287
// - BF16:
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L327-L385
//
template <>
__device__ inline void dequant<half2, vllm::kU4B8.id()>(int q, half2* frag_b) {
const int LO = 0x000f000f;
const int HI = 0x00f000f0;
const int EX = 0x64006400;
// Guarantee that the `(a & b) | c` operations are LOP3s.
// clang-format off
int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX);
int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX);
// clang-format on
// We want signed int4 outputs, hence we fuse the `-8` symmetric zero point
// directly into `SUB` and `ADD`.
const int SUB = 0x64086408;
const int MUL = 0x2c002c00;
const int ADD = 0xd480d480;
frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo),
*reinterpret_cast<const half2*>(&SUB));
frag_b[1] = __hfma2(*reinterpret_cast<half2*>(&hi),
*reinterpret_cast<const half2*>(&MUL),
*reinterpret_cast<const half2*>(&ADD));
}
template <>
__device__ inline void dequant<half2, vllm::kU4.id()>(int q, half2* frag_b) {
const int LO = 0x000f000f;
const int HI = 0x00f000f0;
const int EX = 0x64006400;
// Guarantee that the `(a & b) | c` operations are LOP3s.
// clang-format off
int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX);
int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX);
// clang-format on
// We want signed int4 outputs, hence we fuse the `-8` symmetric zero point
// directly into `SUB` and `ADD`.
const int SUB = 0x64006400;
const int MUL = 0x2c002c00;
const int ADD = 0xd400d400;
frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo),
*reinterpret_cast<const half2*>(&SUB));
frag_b[1] = __hfma2(*reinterpret_cast<half2*>(&hi),
*reinterpret_cast<const half2*>(&MUL),
*reinterpret_cast<const half2*>(&ADD));
}
template <>
__device__ inline void dequant<nv_bfloat162, vllm::kU4B8.id()>(
int q, nv_bfloat162* frag_b) {
static constexpr uint32_t MASK = 0x000f000f;
static constexpr uint32_t EX = 0x43004300;
// Guarantee that the `(a & b) | c` operations are LOP3s.
// clang-format off
int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
q >>= 4;
int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
// clang-format on
static constexpr uint32_t MUL = 0x3F803F80;
static constexpr uint32_t ADD = 0xC308C308;
frag_b[0] = __hfma2(*reinterpret_cast<nv_bfloat162*>(&lo),
*reinterpret_cast<const nv_bfloat162*>(&MUL),
*reinterpret_cast<const nv_bfloat162*>(&ADD));
frag_b[1] = __hfma2(*reinterpret_cast<nv_bfloat162*>(&hi),
*reinterpret_cast<const nv_bfloat162*>(&MUL),
*reinterpret_cast<const nv_bfloat162*>(&ADD));
}
template <>
__device__ inline void dequant<nv_bfloat162, vllm::kU4.id()>(
int q, nv_bfloat162* frag_b) {
static constexpr uint32_t MASK = 0x000f000f;
static constexpr uint32_t EX = 0x43004300;
// Guarantee that the `(a & b) | c` operations are LOP3s.
// clang-format off
int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
q >>= 4;
int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
// clang-format on
static constexpr uint32_t MUL = 0x3F803F80;
static constexpr uint32_t ADD = 0xC300C300;
frag_b[0] = __hfma2(*reinterpret_cast<nv_bfloat162*>(&lo),
*reinterpret_cast<const nv_bfloat162*>(&MUL),
*reinterpret_cast<const nv_bfloat162*>(&ADD));
frag_b[1] = __hfma2(*reinterpret_cast<nv_bfloat162*>(&hi),
*reinterpret_cast<const nv_bfloat162*>(&MUL),
*reinterpret_cast<const nv_bfloat162*>(&ADD));
}
//
// Fast Int8ToFp16/Int8ToBf16: Efficiently dequantize 8bit int values to fp16 or
// bf16 Reference:
// - FP16:
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L53-L85
// - BF16:
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L125-L175
//
template <>
__device__ inline void dequant<half2, vllm::kU8B128.id()>(int q,
half2* frag_b) {
static constexpr uint32_t mask_for_elt_01 = 0x5250;
static constexpr uint32_t mask_for_elt_23 = 0x5351;
static constexpr uint32_t start_byte_for_fp16 = 0x64646464;
uint32_t lo = prmt<start_byte_for_fp16, mask_for_elt_01>(q);
uint32_t hi = prmt<start_byte_for_fp16, mask_for_elt_23>(q);
static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480;
frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo),
*reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
frag_b[1] = __hsub2(*reinterpret_cast<half2*>(&hi),
*reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
}
template <>
__device__ inline void dequant<half2, vllm::kU8.id()>(int q, half2* frag_b) {
static constexpr uint32_t mask_for_elt_01 = 0x5250;
static constexpr uint32_t mask_for_elt_23 = 0x5351;
static constexpr uint32_t start_byte_for_fp16 = 0x64646464;
uint32_t lo = prmt<start_byte_for_fp16, mask_for_elt_01>(q);
uint32_t hi = prmt<start_byte_for_fp16, mask_for_elt_23>(q);
static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64006400;
frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo),
*reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
frag_b[1] = __hsub2(*reinterpret_cast<half2*>(&hi),
*reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
}
template <>
__device__ inline void dequant<nv_bfloat162, vllm::kU8B128.id()>(
int q, nv_bfloat162* frag_b) {
float fp32_intermediates[4];
uint32_t* fp32_intermediates_casted =
reinterpret_cast<uint32_t*>(fp32_intermediates);
static constexpr uint32_t fp32_base = 0x4B000000;
fp32_intermediates_casted[0] = __byte_perm(q, fp32_base, 0x7650);
fp32_intermediates_casted[1] = __byte_perm(q, fp32_base, 0x7652);
fp32_intermediates_casted[2] = __byte_perm(q, fp32_base, 0x7651);
fp32_intermediates_casted[3] = __byte_perm(q, fp32_base, 0x7653);
fp32_intermediates[0] -= 8388736.f;
fp32_intermediates[1] -= 8388736.f;
fp32_intermediates[2] -= 8388736.f;
fp32_intermediates[3] -= 8388736.f;
uint32_t* bf16_result_ptr = reinterpret_cast<uint32_t*>(frag_b);
bf16_result_ptr[0] = __byte_perm(fp32_intermediates_casted[0],
fp32_intermediates_casted[1], 0x7632);
bf16_result_ptr[1] = __byte_perm(fp32_intermediates_casted[2],
fp32_intermediates_casted[3], 0x7632);
}
template <>
__device__ inline void dequant<nv_bfloat162, vllm::kU8.id()>(
int q, nv_bfloat162* frag_b) {
float fp32_intermediates[4];
uint32_t* fp32_intermediates_casted =
reinterpret_cast<uint32_t*>(fp32_intermediates);
static constexpr uint32_t fp32_base = 0x4B000000;
fp32_intermediates_casted[0] = __byte_perm(q, fp32_base, 0x7650);
fp32_intermediates_casted[1] = __byte_perm(q, fp32_base, 0x7652);
fp32_intermediates_casted[2] = __byte_perm(q, fp32_base, 0x7651);
fp32_intermediates_casted[3] = __byte_perm(q, fp32_base, 0x7653);
fp32_intermediates[0] -= 8388608.f;
fp32_intermediates[1] -= 8388608.f;
fp32_intermediates[2] -= 8388608.f;
fp32_intermediates[3] -= 8388608.f;
uint32_t* bf16_result_ptr = reinterpret_cast<uint32_t*>(frag_b);
bf16_result_ptr[0] = __byte_perm(fp32_intermediates_casted[0],
fp32_intermediates_casted[1], 0x7632);
bf16_result_ptr[1] = __byte_perm(fp32_intermediates_casted[2],
fp32_intermediates_casted[3], 0x7632);
}
template <>
__device__ inline void dequant<half2, vllm::kFE4M3fn.id()>(int q,
half2* frag_b) {
// Constants for FP8 (E4M3) and FP16 formats
constexpr int FP8_EXPONENT = 4, FP8_MANTISSA = 3, FP16_EXPONENT = 5;
constexpr int RIGHT_SHIFT = FP16_EXPONENT - FP8_EXPONENT;
// Calculate MASK for extracting mantissa and exponent
constexpr int MASK1 = 0x80000000;
constexpr int MASK2 = MASK1 >> (FP8_EXPONENT + FP8_MANTISSA);
constexpr int MASK3 = MASK2 & 0x7fffffff;
constexpr int MASK = MASK3 | (MASK3 >> 16);
// Final MASK value: 0x7F007F00
// Extract and shift FP8 values to FP16 format
int Out1 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT);
int Out2 = ((q << 8) & 0x80008000) | (((q << 8) & MASK) >> RIGHT_SHIFT);
// Construct and apply exponent bias
constexpr int BIAS_OFFSET =
(1 << (FP16_EXPONENT - 1)) - (1 << (FP8_EXPONENT - 1));
const half2 bias_reg = __float2half2_rn(float(1 << BIAS_OFFSET));
// Convert to half2 and apply bias
// Note: reverse indexing is intentional because weights are permuted
frag_b[1] = __hmul2(*reinterpret_cast<const half2*>(&Out1), bias_reg);
frag_b[0] = __hmul2(*reinterpret_cast<const half2*>(&Out2), bias_reg);
}
template <>
__device__ inline void dequant<nv_bfloat162, vllm::kFE4M3fn.id()>(
int q, nv_bfloat162* frag_b) {
// Constants for FP8 (E4M3) and BF16 formats
constexpr int FP8_EXPONENT = 4, FP8_MANTISSA = 3, BF16_EXPONENT = 8;
constexpr int RIGHT_SHIFT = BF16_EXPONENT - FP8_EXPONENT;
// Calculate MASK for extracting mantissa and exponent
constexpr int MASK1 = 0x80000000;
constexpr int MASK2 = MASK1 >> (FP8_EXPONENT + FP8_MANTISSA);
constexpr int MASK3 = MASK2 & 0x7fffffff;
constexpr int MASK = MASK3 | (MASK3 >> 16);
// Final MASK value: 0x7F007F00
// Extract and shift FP8 values to BF16 format
int Out1 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT);
int Out2 = ((q << 8) & 0x80008000) | (((q << 8) & MASK) >> RIGHT_SHIFT);
// Construct and apply exponent bias
constexpr int BIAS_OFFSET =
(1 << (BF16_EXPONENT - 1)) - (1 << (FP8_EXPONENT - 1));
// Add 127 (float exponent bias) to BIAS_OFFSET and shift to float exponent
// position
constexpr uint32_t BIAS = (BIAS_OFFSET + 127) << 23;
const nv_bfloat162 bias_reg =
__float2bfloat162_rn(*reinterpret_cast<const float*>(&BIAS));
// Convert to bfloat162 and apply bias
// Note: reverse indexing is intentional because weights are permuted
frag_b[1] = __hmul2(*reinterpret_cast<const nv_bfloat162*>(&Out1), bias_reg);
frag_b[0] = __hmul2(*reinterpret_cast<const nv_bfloat162*>(&Out2), bias_reg);
}
#endif
} // namespace MARLIN_NAMESPACE_NAME

View File

@ -0,0 +1,116 @@
# SPDX-License-Identifier: Apache-2.0
import glob
import itertools
import os
import subprocess
import jinja2
FILE_HEAD = """
// auto generated by generate.py
// clang-format off
#include "kernel.h"
#include "marlin_template.h"
namespace MARLIN_NAMESPACE_NAME {
""".strip()
TEMPLATE = ("template __global__ void Marlin<"
"{{scalar_t}}, "
"{{w_type_id}}, "
"{{threads}}, "
"{{thread_m_blocks}}, "
"{{thread_n_blocks}}, "
"{{thread_k_blocks}}, "
"{{'true' if m_block_size_8 else 'false'}}, "
"{{stages}}, "
"{{group_blocks}}, "
"{{'true' if is_zp_float else 'false'}}>"
"( MARLIN_KERNEL_PARAMS );")
# int8 with zero point case (vllm::kU8) is also supported,
# we don't add it to reduce wheel size.
SCALAR_TYPES = ["vllm::kU4", "vllm::kU4B8", "vllm::kU8B128", "vllm::kFE4M3fn"]
THREAD_CONFIGS = [(128, 128, 256), (64, 256, 256), (64, 128, 128),
(128, 64, 128)]
THREAD_M_BLOCKS = [0.5, 1, 2, 3, 4]
# group_blocks:
# = 0 : act order case
# = -1 : channelwise quantization
# > 0 : group_size=16*group_blocks
GROUP_BLOCKS = [0, -1, 2, 4, 8]
DTYPES = ["fp16", "bf16"]
def remove_old_kernels():
for filename in glob.glob(os.path.dirname(__file__) + "/kernel_*.cu"):
subprocess.call(["rm", "-f", filename])
def generate_new_kernels():
for scalar_type, dtype in itertools.product(SCALAR_TYPES, DTYPES):
all_template_str_list = []
for group_blocks, m_blocks, thread_configs in itertools.product(
GROUP_BLOCKS, THREAD_M_BLOCKS, THREAD_CONFIGS):
# act order case only support gptq-int4 and gptq-int8
if group_blocks == 0 and scalar_type not in [
"vllm::kU4B8", "vllm::kU8B128"
]:
continue
if thread_configs[2] == 256:
# for small batch (m_blocks == 1), we only need (128, 128, 256)
# for large batch (m_blocks > 1), we only need (64, 256, 256)
if m_blocks <= 1 and thread_configs[0] != 128:
continue
if m_blocks > 1 and thread_configs[0] != 64:
continue
# we only support channelwise quantization and group_size == 128
# for fp8
if scalar_type == "vllm::kFE4M3fn" and group_blocks not in [-1, 8]:
continue
k_blocks = thread_configs[0] // 16
n_blocks = thread_configs[1] // 16
threads = thread_configs[2]
c_dtype = "half" if dtype == "fp16" else "nv_bfloat16"
is_zp_float_list = [False]
if dtype == "fp16" and scalar_type == "vllm::kU4" and \
group_blocks == 4:
# HQQ (is_zp_float = true) only supports
# 4bit quantization and fp16
is_zp_float_list.append(True)
for is_zp_float in is_zp_float_list:
template_str = jinja2.Template(TEMPLATE).render(
scalar_t=c_dtype,
w_type_id=scalar_type + ".id()",
threads=threads,
thread_m_blocks=max(m_blocks, 1),
thread_n_blocks=n_blocks,
thread_k_blocks=k_blocks,
m_block_size_8=m_blocks == 0.5,
stages="pipe_stages",
group_blocks=group_blocks,
is_zp_float=is_zp_float,
)
all_template_str_list.append(template_str)
file_content = FILE_HEAD + "\n\n"
file_content += "\n\n".join(all_template_str_list) + "\n\n}\n"
filename = f"kernel_{dtype}_{scalar_type[6:].lower()}.cu"
with open(os.path.join(os.path.dirname(__file__), filename), "w") as f:
f.write(file_content)
if __name__ == "__main__":
remove_old_kernels()
generate_new_kernels()

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,37 @@
#ifndef MARLIN_NAMESPACE_NAME
#define MARLIN_NAMESPACE_NAME marlin
#endif
#include "marlin.cuh"
#include "marlin_dtypes.cuh"
#include "core/scalar_type.hpp"
#define MARLIN_KERNEL_PARAMS \
const int4 *__restrict__ A, const int4 *__restrict__ B, \
int4 *__restrict__ C, int4 *__restrict__ C_tmp, \
const int4 *__restrict__ scales_ptr, const int4 *__restrict__ zp_ptr, \
const int *__restrict__ g_idx, int num_groups, int prob_m, int prob_n, \
int prob_k, int lda, int *locks, bool use_atomic_add, \
bool use_fp32_reduce, int max_shared_mem
namespace MARLIN_NAMESPACE_NAME {
template <typename scalar_t, // compute dtype, half or nv_float16
const vllm::ScalarTypeId w_type_id, // weight ScalarType id
const int threads, // number of threads in a threadblock
const int thread_m_blocks, // number of 16x16 blocks in the m
// dimension (batchsize) of the
// threadblock
const int thread_n_blocks, // same for n dimension (output)
const int thread_k_blocks, // same for k dimension (reduction)
const bool m_block_size_8, // whether m_block_size == 8
// only works when thread_m_blocks == 1
const int stages, // number of stages for the async global->shared
// fetch pipeline
const int group_blocks, // number of consecutive 16x16 blocks
// with a separate quantization scale
const bool is_zp_float // is zero point of float16 type?
>
__global__ void Marlin(MARLIN_KERNEL_PARAMS);
}

File diff suppressed because it is too large Load Diff

View File

@ -291,12 +291,11 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// gptq_marlin Optimized Quantized GEMM for GPTQ.
ops.def(
"gptq_marlin_gemm(Tensor a, Tensor b_q_weight, Tensor b_scales, "
"Tensor b_zeros, Tensor g_idx, Tensor perm, Tensor workspace, "
"int b_q_type, "
"gptq_marlin_gemm(Tensor a, Tensor? c_or_none, Tensor b_q_weight, "
"Tensor b_scales, Tensor? b_zeros_or_none, Tensor? g_idx_or_none, "
"Tensor? perm_or_none, Tensor workspace, int b_q_type, "
"SymInt size_m, SymInt size_n, SymInt size_k, bool is_k_full, "
"bool has_zp, bool use_atomic_add, bool use_fp32_reduce, "
"bool is_zp_float) -> Tensor",
"bool use_atomic_add, bool use_fp32_reduce, bool is_zp_float) -> Tensor",
{stride_tag});
// conditionally compiled so impl registration is in source file
@ -341,14 +340,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops.def("ggml_moe_get_block_size", &ggml_moe_get_block_size);
#ifndef USE_ROCM
// fp8_marlin Optimized Quantized GEMM for FP8 weight-only.
ops.def(
"fp8_marlin_gemm(Tensor a, Tensor b_q_weight, Tensor b_scales, "
"Tensor! workspace, int num_bits, SymInt size_m, SymInt size_n, "
"SymInt size_k) -> Tensor",
{stride_tag});
// conditionally compiled so impl registration is in source file
// marlin_qqq_gemm for QQQ.
ops.def(
"marlin_qqq_gemm(Tensor a, Tensor b_q_weight, "

View File

@ -11,19 +11,20 @@ from transformers import MixtralConfig
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
import vllm.model_executor.layers.fused_moe # noqa
from tests.kernels.utils import (opcheck, stack_and_dev, torch_moe,
torch_moe_single)
from tests.kernels.utils import opcheck, stack_and_dev, torch_moe
from vllm.model_executor.layers.fused_moe import fused_moe
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
from vllm.model_executor.layers.fused_moe.moe_torch_iterative import (
fused_moe as iterative_moe)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
marlin_quant_fp8_torch)
from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
awq_marlin_quantize, marlin_quantize)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
quantize_weights)
from vllm.model_executor.models.mixtral import MixtralMoE
from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types
from vllm.scalar_type import ScalarType, scalar_types
NUM_EXPERTS = [8, 64]
EP_SIZE = [1, 4]
@ -285,7 +286,7 @@ def test_mixtral_moe(dtype: torch.dtype, padding: bool, use_rocm_aiter: bool,
atol=mixtral_moe_tol[dtype])
@pytest.mark.parametrize("m", [1, 33, 123])
@pytest.mark.parametrize("m", [1, 123, 666])
@pytest.mark.parametrize("n", [128, 1024])
@pytest.mark.parametrize("k", [256, 2048])
@pytest.mark.parametrize("e", [4, 12])
@ -294,8 +295,10 @@ def test_mixtral_moe(dtype: torch.dtype, padding: bool, use_rocm_aiter: bool,
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("group_size", [-1, 32, 128])
@pytest.mark.parametrize("act_order", [True, False])
@pytest.mark.parametrize("num_bits", [4, 8])
@pytest.mark.parametrize("has_zp", [True, False])
@pytest.mark.parametrize("quant_type", [
scalar_types.uint4, scalar_types.uint8b128, scalar_types.uint4b8,
scalar_types.float8_e4m3fn
])
@pytest.mark.parametrize("is_k_full", [True, False])
@pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm")
def test_fused_marlin_moe(
@ -308,14 +311,22 @@ def test_fused_marlin_moe(
dtype: torch.dtype,
group_size: int,
act_order: bool,
num_bits: int,
has_zp: bool,
quant_type: ScalarType,
is_k_full: bool,
):
current_platform.seed_everything(7)
torch.cuda.manual_seed(0)
has_zp = quant_type in [scalar_types.uint4, scalar_types.uint8]
if quant_type == scalar_types.float8_e4m3fn:
if group_size not in [-1, 128]:
return
if act_order:
return
# Filter act_order
if act_order:
if quant_type == scalar_types.float8_e4m3fn:
return
if group_size == -1:
return
if group_size in (k, n):
@ -326,17 +337,9 @@ def test_fused_marlin_moe(
if not is_k_full:
return
if has_zp:
# we don't build kernel for int8 with zero
if num_bits == 8:
return
quant_type = scalar_types.uint4 if num_bits == 4 else scalar_types.uint8
else:
quant_type = scalar_types.uint4b8 \
if num_bits == 4 else scalar_types.uint8b128
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 20
w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 20
if ep_size > 1:
local_e = e // ep_size
@ -364,17 +367,23 @@ def test_fused_marlin_moe(
qweight1_l.append(qweight1)
scales1_l.append(scales1)
zeros1_l.append(zeros1)
else:
elif quant_type != scalar_types.float8_e4m3fn:
test_perm = torch.randperm(k)
quant_res = marlin_quantize(w1[i].transpose(1, 0), quant_type,
group_size, act_order, test_perm)
w_ref1, qweight1, scales1, g_idx1, sort_indices1, _ = quant_res
w_ref1, qweight1, scales1, g_idx1, sort_indices1, _ = \
marlin_quantize(w1[i].transpose(1, 0), quant_type,
group_size, act_order, test_perm)
w_ref1_l.append(w_ref1.T)
qweight1_l.append(qweight1)
scales1_l.append(scales1)
g_idx1_l.append(g_idx1)
sort_indices1_l.append(sort_indices1)
else:
w_ref1, qweight1, scales1 = marlin_quant_fp8_torch(
w1[i], group_size)
w_ref1_l.append(w_ref1.T)
qweight1_l.append(qweight1)
scales1_l.append(scales1)
w_ref1 = stack_and_dev(w_ref1_l)
qweight1 = stack_and_dev(qweight1_l).contiguous()
@ -399,17 +408,23 @@ def test_fused_marlin_moe(
qweight2_l.append(qweight2)
scales2_l.append(scales2)
zeros2_l.append(zeros2)
else:
elif quant_type != scalar_types.float8_e4m3fn:
test_perm = torch.randperm(n)
quant_res = marlin_quantize(w2[i].transpose(1, 0), quant_type,
group_size, act_order, test_perm)
w_ref2, qweight2, scales2, g_idx2, sort_indices2, _ = quant_res
w_ref2, qweight2, scales2, g_idx2, sort_indices2, _ = \
marlin_quantize(w2[i].transpose(1, 0), quant_type,
group_size, act_order, test_perm)
w_ref2_l.append(w_ref2.T)
qweight2_l.append(qweight2)
scales2_l.append(scales2)
g_idx2_l.append(g_idx2)
sort_indices2_l.append(sort_indices2)
else:
w_ref2, qweight2, scales2 = marlin_quant_fp8_torch(
w2[i], group_size)
w_ref2_l.append(w_ref2.T)
qweight2_l.append(qweight2)
scales2_l.append(scales2)
w_ref2 = stack_and_dev(w_ref2_l)
qweight2 = stack_and_dev(qweight2_l).contiguous()
@ -442,102 +457,10 @@ def test_fused_marlin_moe(
sort_indices2=sort_indices2,
w1_zeros=zeros1,
w2_zeros=zeros2,
num_bits=num_bits,
quant_type_id=quant_type.id,
is_k_full=is_k_full)
torch.testing.assert_close(marlin_output, torch_output, atol=2e-2, rtol=0)
@pytest.mark.skip("This test is here for the sake of debugging, "
"don't run it in automated tests.")
@pytest.mark.parametrize("m", [1, 33, 123])
@pytest.mark.parametrize("n", [128, 1024])
@pytest.mark.parametrize("k", [256, 2048])
@pytest.mark.parametrize("e", [4, 12])
@pytest.mark.parametrize("topk", [2, 3])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("group_size", [-1, 32, 128])
@pytest.mark.parametrize("act_order", [True, False])
@pytest.mark.parametrize("num_bits", [4, 8])
@pytest.mark.parametrize("has_zp", [True, False])
@pytest.mark.parametrize("is_k_full", [True, False])
def test_single_marlin_moe_multiply(m: int, n: int, k: int, e: int, topk: int,
dtype: torch.dtype, group_size: int,
act_order: bool, num_bits: int,
has_zp: bool, is_k_full: bool):
# Filter act_order
if act_order:
if group_size == -1:
return
if group_size in (k, n):
return
if has_zp:
return
else:
if not is_k_full:
return
if has_zp:
quant_type = scalar_types.uint4 if num_bits == 4 else scalar_types.uint8
else:
quant_type = scalar_types.uint4b8 \
if num_bits == 4 else scalar_types.uint8b128
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
w = torch.randn((e, n, k), device="cuda", dtype=dtype) / 10
w_ref_l = []
qweight_l = []
scales_l = []
zeros_l = []
g_idx_l = []
sort_indices_l = []
for i in range(w.shape[0]):
if has_zp:
w_ref, qweight, scales, zeros = awq_marlin_quantize(
w[i].transpose(1, 0), quant_type, group_size)
w_ref_l.append(w_ref.T)
qweight_l.append(qweight)
scales_l.append(scales)
zeros_l.append(zeros)
else:
test_perm = torch.randperm(k)
w_ref, qweight, scales, g_idx, sort_indices, _ = marlin_quantize(
w[i].transpose(1, 0), quant_type, group_size, act_order,
test_perm)
w_ref_l.append(w_ref.T)
qweight_l.append(qweight)
scales_l.append(scales)
g_idx_l.append(g_idx)
sort_indices_l.append(sort_indices)
w_ref = stack_and_dev(w_ref_l)
qweight = stack_and_dev(qweight_l).contiguous()
scales = stack_and_dev(scales_l)
g_idx = stack_and_dev(g_idx_l) if g_idx_l else None
zeros = stack_and_dev(zeros_l) if zeros_l else None
sort_indices = stack_and_dev(sort_indices_l) if sort_indices_l else None
score = torch.randn((m, e), device="cuda", dtype=dtype)
marlin_output = torch.ops.vllm.single_marlin_moe(
a,
qweight,
scales,
score,
topk,
renormalize=False,
g_idx=g_idx,
sort_indices=sort_indices,
w_zeros=zeros,
num_bits=num_bits,
is_k_full=is_k_full,
)
torch_output = torch_moe_single(a, w_ref, score, topk)
torch.testing.assert_close(marlin_output, torch_output, atol=2e-2, rtol=0)
torch.testing.assert_close(marlin_output, torch_output, atol=5e-2, rtol=0)
def test_moe_align_block_size_opcheck():

View File

@ -1,164 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
"""Test AWQ with fused MoE Marlin kernels.
Run `pytest tests/kernels/test_awq_marlin.py`.
"""
import pytest
import torch
import vllm.model_executor.layers.fused_moe # noqa
from tests.kernels.utils import (compute_max_diff, stack_and_dev, torch_moe,
torch_moe_single)
from vllm import _custom_ops as ops
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
awq_marlin_quantize)
from vllm.scalar_type import scalar_types
NUM_EXPERTS = [8, 64]
TOP_KS = [2, 6]
GROUP_SIZES = [-1, 32, 128]
@pytest.mark.parametrize("m", [1, 33, 64, 222])
@pytest.mark.parametrize("n", [128, 2048])
@pytest.mark.parametrize("k", [128, 1024])
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("group_size", GROUP_SIZES)
@pytest.mark.skipif(not (ops.supports_moe_ops
and hasattr(torch.ops._moe_C, "marlin_gemm_moe")),
reason="Marlin is not supported on this GPU type.")
def test_fused_marlin_moe_awq(
m: int,
n: int,
k: int,
e: int,
topk: int,
group_size: int,
):
torch.manual_seed(7)
num_bits = 4
quant_type = scalar_types.uint4
dtype = torch.float16
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
w_ref1_l = []
qweights1_l = []
scales1_l = []
zp1_l = []
for i in range(w1.shape[0]):
w_ref1, qweight1, scales1, zp1 = awq_marlin_quantize(
w1[i].transpose(1, 0), quant_type, group_size)
w_ref1_l.append(w_ref1)
qweights1_l.append(qweight1)
scales1_l.append(scales1)
zp1_l.append(zp1)
w_ref1 = stack_and_dev(w_ref1_l)
qweight1 = stack_and_dev(qweights1_l).contiguous()
scales1 = stack_and_dev(scales1_l)
zp1 = stack_and_dev(zp1_l)
w_ref2_l = []
qweights2_l = []
scales2_l = []
zp2_l = []
for i in range(w2.shape[0]):
w_ref2, qweight2, scales2, zp2 = awq_marlin_quantize(
w2[i].transpose(1, 0), quant_type, group_size)
w_ref2_l.append(w_ref2)
qweights2_l.append(qweight2)
scales2_l.append(scales2)
zp2_l.append(zp2)
w_ref2 = stack_and_dev(w_ref2_l)
qweight2 = stack_and_dev(qweights2_l).contiguous()
scales2 = stack_and_dev(scales2_l)
zp2 = stack_and_dev(zp2_l)
score = torch.randn((m, e), device="cuda", dtype=dtype)
topk_weights, topk_ids, token_expert_indices = fused_topk(
a, score, topk, False)
marlin_output = torch.ops.vllm.fused_marlin_moe(
a,
qweight1,
qweight2,
scales1,
scales2,
score,
topk_weights,
topk_ids,
w1_zeros=zp1,
w2_zeros=zp2,
num_bits=num_bits,
)
torch_output = torch_moe(a, w_ref1.transpose(1, 2), w_ref2.transpose(1, 2),
score, topk, None)
assert compute_max_diff(marlin_output, torch_output) < 4e-2
@pytest.mark.skip("This test is here for the sake of debugging, "
"don't run it in automated tests.")
@pytest.mark.parametrize("m", [64, 512, 222, 33, 1])
@pytest.mark.parametrize("n", [128, 2048, 256, 1024])
@pytest.mark.parametrize("k", [128, 1024, 512])
@pytest.mark.parametrize("e", [8, 64])
@pytest.mark.parametrize("topk", [2, 6])
@pytest.mark.parametrize("group_size", [-1, 32, 64, 128])
def test_single_marlin_moe_multiply_awq(
m: int,
n: int,
k: int,
e: int,
topk: int,
group_size: int,
):
torch.manual_seed(7)
num_bits = 4
quant_type = scalar_types.uint4
dtype = torch.float16
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
w = torch.randn((e, n, k), device="cuda", dtype=dtype) / 10
w_ref_l = []
qweights_l = []
scales_l = []
zp_l = []
for i in range(w.shape[0]):
w_ref, qweight, scales, zp = awq_marlin_quantize(
w[i].transpose(1, 0), quant_type, group_size)
w_ref_l.append(w_ref)
qweights_l.append(qweight)
scales_l.append(scales)
zp_l.append(zp)
w_ref = stack_and_dev(w_ref_l)
qweight = stack_and_dev(qweights_l).contiguous()
scales = stack_and_dev(scales_l).contiguous()
zp = stack_and_dev(zp_l).contiguous()
score = torch.randn((m, e), device="cuda", dtype=dtype)
marlin_output = torch.ops.vllm.single_marlin_moe(a,
qweight,
scales,
score,
topk,
renormalize=False,
w_zeros=zp,
num_bits=num_bits)
torch_output = torch_moe_single(a, w_ref.transpose(1, 2), score, topk)
assert compute_max_diff(marlin_output, torch_output) < 1e-2

View File

@ -18,9 +18,10 @@ from vllm.model_executor.layers.quantization.qqq import (
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N,
MARLIN_SUPPORTED_GROUP_SIZES, marlin_make_empty_g_idx,
marlin_permute_scales, query_marlin_supported_quant_types)
marlin_make_workspace_new, marlin_permute_scales,
query_marlin_supported_quant_types)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
pack_fp8_to_int32)
marlin_quant_fp8_torch)
from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
MarlinWorkspace, awq_marlin_quantize, get_weight_perm, marlin_quantize,
marlin_weights)
@ -73,7 +74,7 @@ def rand_data(shape, dtype=torch.float16):
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
@pytest.mark.parametrize("quant_type",
query_marlin_supported_quant_types(False))
query_marlin_supported_quant_types(False, False))
@pytest.mark.parametrize("group_size", MARLIN_SUPPORTED_GROUP_SIZES)
@pytest.mark.parametrize("act_order", ACT_ORDER_OPTS)
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
@ -138,7 +139,7 @@ def test_gptq_marlin_repack(k_chunk, n_chunk, quant_type, group_size,
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
@pytest.mark.parametrize("quant_type",
query_marlin_supported_quant_types(False))
query_marlin_supported_quant_types(True))
@pytest.mark.parametrize("group_size", MARLIN_SUPPORTED_GROUP_SIZES)
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
def test_awq_marlin_repack(k_chunk, n_chunk, quant_type, group_size,
@ -220,38 +221,50 @@ def test_gptq_marlin_gemm(
if group_size == size_k:
return
if size_k % group_size != 0:
return
a_input = rand_data((size_m, size_k))
b_weight = rand_data((size_k, size_n))
w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, _ = marlin_quantize(
b_weight, quant_type, group_size, act_order)
if quant_type == scalar_types.float8_e4m3fn:
if group_size not in [-1, 128]:
return
if act_order:
return
w_ref, marlin_q_w, marlin_s = marlin_quant_fp8_torch(
b_weight.T, group_size)
g_idx = None
sort_indices = None
else:
w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, _ = marlin_quantize(
b_weight, quant_type, group_size, act_order)
marlin_zp = marlin_make_empty_g_idx(marlin_s.device)
workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_MIN_THREAD_N,
GPTQ_MARLIN_MAX_PARALLEL)
workspace = marlin_make_workspace_new(w_ref.device)
opcheck(torch.ops._C.gptq_marlin_gemm,
(a_input, marlin_q_w, marlin_s, marlin_zp, g_idx, sort_indices,
workspace.scratch, quant_type.id, a_input.shape[0],
b_weight.shape[1], a_input.shape[1], is_k_full, False,
use_atomic_add, use_fp32_reduce, False),
test_utils=DEFAULT_OPCHECK_TEST_UTILS)
opcheck(
torch.ops._C.gptq_marlin_gemm,
(a_input, None, marlin_q_w, marlin_s, marlin_zp, g_idx, sort_indices,
workspace, quant_type.id, a_input.shape[0], b_weight.shape[1],
a_input.shape[1], is_k_full, use_atomic_add, use_fp32_reduce, False),
test_utils=DEFAULT_OPCHECK_TEST_UTILS)
output = ops.gptq_marlin_gemm(
a_input,
None,
marlin_q_w,
marlin_s,
marlin_zp,
g_idx,
sort_indices,
workspace.scratch,
workspace,
quant_type,
a_input.shape[0],
b_weight.shape[1],
a_input.shape[1],
is_k_full=is_k_full,
has_zp=False,
use_atomic_add=use_atomic_add,
use_fp32_reduce=use_fp32_reduce,
is_zp_float=False,
@ -326,80 +339,6 @@ def test_gptq_marlin_24_gemm(k_chunk, n_chunk, quant_type, group_size,
assert max_diff < 0.04
@pytest.mark.skipif(not is_quant_method_supported("fp8"),
reason="Marlin is not supported on this GPU type.")
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
@pytest.mark.parametrize("num_bits", [8])
@pytest.mark.parametrize("group_size", [-1])
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
@pytest.mark.parametrize("dtype", DTYPES)
def test_fp8_marlin_gemm(
k_chunk,
n_chunk,
num_bits,
group_size,
mnk_factors,
dtype,
):
m_factor, n_factor, k_factor = mnk_factors
size_m = m_factor
size_k = k_chunk * k_factor
size_n = n_chunk * n_factor
a_input = rand_data((size_m, size_k), dtype=dtype)
b_weight = rand_data((size_k, size_n), dtype=dtype)
# WEIGHTS
fp8_weight, weight_scale = ops.scaled_fp8_quant(b_weight, scale=None)
# Repack weights to gptq format (packed int32 elements)
packed_gptq_qweight = pack_fp8_to_int32(fp8_weight)
# Repack weights to marlin format
marlin_qweight = ops.gptq_marlin_repack(
b_q_weight=packed_gptq_qweight,
perm=torch.empty(0, dtype=torch.int, device="cuda"),
size_k=size_k,
size_n=size_n,
num_bits=8,
)
# WEIGHT SCALES
# Currently Marlin doesn't support per-tensor scales, so we
# expand it to channelwise
scales = weight_scale.repeat(1, size_n).to(a_input.dtype).to("cuda")
# Permute scales
marlin_scales = marlin_permute_scales(s=scales,
size_k=size_k,
size_n=size_n,
group_size=-1)
workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_MIN_THREAD_N,
GPTQ_MARLIN_MAX_PARALLEL)
opcheck(torch.ops._C.fp8_marlin_gemm,
(a_input, marlin_qweight, marlin_scales, workspace.scratch,
num_bits, a_input.shape[0], b_weight.shape[1], a_input.shape[1]))
output = ops.fp8_marlin_gemm(
a=a_input,
b_q_weight=marlin_qweight,
b_scales=marlin_scales,
workspace=workspace.scratch,
num_bits=num_bits,
size_m=a_input.shape[0],
size_n=b_weight.shape[1],
size_k=a_input.shape[1],
)
output_ref = torch.matmul(a_input, b_weight)
torch.cuda.synchronize()
max_diff = compute_max_diff(output, output_ref)
assert max_diff < 0.04
@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"),
reason="Marlin is not supported on this GPU type.")
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
@ -432,25 +371,23 @@ def test_awq_marlin_gemm(
g_idx = torch.empty(0, dtype=torch.int, device=marlin_q_w.device)
sort_indices = torch.empty(0, dtype=torch.int, device=marlin_q_w.device)
is_k_full = True
has_zp = True
workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_MIN_THREAD_N,
GPTQ_MARLIN_MAX_PARALLEL)
workspace = marlin_make_workspace_new(a_input.device)
output = ops.gptq_marlin_gemm(
a_input,
None,
marlin_q_w,
marlin_s,
marlin_zp,
g_idx,
sort_indices,
workspace.scratch,
workspace,
quant_type,
a_input.shape[0],
b_weight.shape[1],
a_input.shape[1],
is_k_full=is_k_full,
has_zp=has_zp,
use_fp32_reduce=use_fp32_reduce,
is_zp_float=False,
)
@ -508,23 +445,22 @@ def test_hqq_marlin_gemm(
g_idx = marlin_make_empty_g_idx(dev)
g_idx_sort_indices = marlin_make_empty_g_idx(dev)
workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_MIN_THREAD_N,
GPTQ_MARLIN_MAX_PARALLEL)
workspace = marlin_make_workspace_new(b_weight.device)
output = ops.gptq_marlin_gemm(
a_input,
None,
marlin_w_q,
marlin_s,
marlin_zp,
g_idx,
g_idx_sort_indices,
workspace.scratch,
workspace,
quant_type,
a_input.shape[0],
b_weight.shape[0],
a_input.shape[1],
is_k_full=True,
has_zp=True,
use_fp32_reduce=use_fp32_reduce,
is_zp_float=True,
)
@ -621,23 +557,22 @@ def test_marlin_gemm_subset_input():
b_weight, quant_type, group_size, False)
marlin_zp = marlin_make_empty_g_idx(marlin_s.device)
workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_MIN_THREAD_N,
GPTQ_MARLIN_MAX_PARALLEL)
workspace = marlin_make_workspace_new(a_input.device)
output = ops.gptq_marlin_gemm(
a_input,
None,
marlin_q_w,
marlin_s,
marlin_zp,
g_idx,
sort_indices,
workspace.scratch,
workspace,
quant_type,
a_input.shape[0],
b_weight.shape[1],
a_input.shape[1],
is_k_full=True,
has_zp=False,
use_atomic_add=False,
use_fp32_reduce=True,
is_zp_float=False,

View File

@ -325,18 +325,18 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
@register_fake("_C::gptq_marlin_gemm")
def _gptq_marlin_gemm_fake(a: torch.Tensor,
c: Optional[torch.Tensor],
b_q_weight: torch.Tensor,
b_scales: torch.Tensor,
b_zeros: torch.Tensor,
g_idx: torch.Tensor,
perm: torch.Tensor,
b_zeros: Optional[torch.Tensor],
g_idx: Optional[torch.Tensor],
perm: Optional[torch.Tensor],
workspace: torch.Tensor,
b_q_type: ScalarType,
b_q_type_id: int,
size_m: torch.SymInt,
size_n: torch.SymInt,
size_k: torch.SymInt,
is_k_full: bool,
has_zp: bool = False,
is_k_full: bool = True,
use_atomic_add: bool = False,
use_fp32_reduce: bool = False,
is_zp_float: bool = False) -> torch.Tensor:
@ -407,14 +407,6 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
dtype=codebooks.dtype,
device=codebooks.device)
@register_fake("_C::fp8_marlin_gemm")
def _fp8_marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
b_scales: torch.Tensor, workspace: torch.Tensor,
num_bits: int, size_m: torch.SymInt,
size_n: torch.SymInt,
size_k: torch.SymInt) -> torch.Tensor:
return torch.empty((size_m, size_n), dtype=a.dtype, device=a.device)
@register_fake("_C::machete_mm")
def machete_mm_fake(
a: torch.Tensor,
@ -815,35 +807,26 @@ def awq_marlin_moe_repack(b_q_weight: torch.Tensor, perm: torch.Tensor,
def gptq_marlin_gemm(a: torch.Tensor,
c: Optional[torch.Tensor],
b_q_weight: torch.Tensor,
b_scales: torch.Tensor,
b_zeros: torch.Tensor,
g_idx: torch.Tensor,
perm: torch.Tensor,
b_zeros: Optional[torch.Tensor],
g_idx: Optional[torch.Tensor],
perm: Optional[torch.Tensor],
workspace: torch.Tensor,
b_q_type: ScalarType,
size_m: int,
size_n: int,
size_k: int,
is_k_full: bool,
has_zp: bool = False,
is_k_full: bool = True,
use_atomic_add: bool = False,
use_fp32_reduce: bool = False,
is_zp_float: bool = False) -> torch.Tensor:
return torch.ops._C.gptq_marlin_gemm(a, b_q_weight, b_scales, b_zeros,
return torch.ops._C.gptq_marlin_gemm(a, c, b_q_weight, b_scales, b_zeros,
g_idx, perm, workspace, b_q_type.id,
size_m, size_n, size_k, is_k_full,
has_zp, use_atomic_add,
use_fp32_reduce, is_zp_float)
# fp8 marlin
def fp8_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
b_scales: torch.Tensor, workspace: torch.Tensor,
num_bits: int, size_m: int, size_n: int,
size_k: int) -> torch.Tensor:
return torch.ops._C.fp8_marlin_gemm(a, b_q_weight, b_scales, workspace,
num_bits, size_m, size_n, size_k)
use_atomic_add, use_fp32_reduce,
is_zp_float)
# machete

View File

@ -7,163 +7,13 @@ import torch
import vllm._custom_ops as ops
from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_topk, moe_align_block_size, try_get_optimal_moe_config)
from vllm.scalar_type import scalar_types
moe_align_block_size, try_get_optimal_moe_config)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
marlin_make_workspace_new, maybe_warn_marlin_atomic_add)
from vllm.scalar_type import ScalarType, scalar_types
from vllm.utils import direct_register_custom_op
def get_scalar_type(num_bits: int, has_zp: bool):
if has_zp:
return scalar_types.uint4 if num_bits == 4 else scalar_types.uint8
else:
return scalar_types.uint4b8 if num_bits == 4 else scalar_types.uint8b128
def single_marlin_moe(
hidden_states: torch.Tensor,
w: torch.Tensor,
scales: torch.Tensor,
gating_output: torch.Tensor,
topk: int,
renormalize: bool,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
g_idx: Optional[torch.Tensor] = None,
sort_indices: Optional[torch.Tensor] = None,
w_zeros: Optional[torch.Tensor] = None,
workspace: Optional[torch.Tensor] = None,
num_bits: int = 8,
is_k_full: bool = True,
) -> torch.Tensor:
"""
This function computes the multiplication of hidden_states with expert
weights used in Marlin MoE, using weights w and top-k gating mechanism.
Its purpose is testing and debugging the fused MoE kernel.
Parameters:
- hidden_states (torch.Tensor): The input tensor to the Marlin Mul.
- w (torch.Tensor): The set of expert weights.
- scales (torch.Tensor): The quantization scales.
- gating_output (torch.Tensor): The output of the gating operation
(before softmax).
- g_idx (Optional[torch.Tensor]): Optional act_order indices.
- sort_indices (Optional[torch.Tensor]): Optional act_order input
permutation.
- topk (int): The number of top-k experts to select.
- renormalize (bool): If True, renormalize the top-k weights to sum to 1.
- w_zeros (Optional[torch.Tensor]): Optional zero points to be used for w.
- num_bits (bool): The number of bits in expert weights quantization.
Returns:
- torch.Tensor: The output tensor after applying the MoE layer.
"""
# Check constraints.
assert hidden_states.shape[0] == gating_output.shape[0], (
"Number of tokens mismatch")
assert hidden_states.shape[1] == w.shape[1] * 16, "Hidden size mismatch"
assert gating_output.shape[1] == w.shape[0], "Number of experts mismatch"
assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
assert w.is_contiguous(), "Expert weights must be contiguous"
assert hidden_states.dtype in [torch.float16, torch.bfloat16]
assert num_bits in [4, 8]
M, K = hidden_states.shape
E = w.shape[0]
N = w.shape[2] // (num_bits // 2)
topk_weights, topk_ids, token_expert_indices = fused_topk(
hidden_states, gating_output, topk, renormalize)
# This might not be an optimal config for a single MMM
get_config_func = functools.partial(try_get_optimal_moe_config,
w.shape,
w.shape,
topk_ids.shape[1],
None,
is_marlin=True)
config = get_config_func(M)
block_size_m = config['BLOCK_SIZE_M']
if global_num_experts == -1:
global_num_experts = E
sorted_token_ids, expert_ids, num_tokens_post_padded = \
moe_align_block_size(topk_ids, block_size_m, E, expert_map)
if workspace is None:
max_workspace_size = (max(2 * N, K) // 64) * \
(sorted_token_ids.size(0) // block_size_m)
device = hidden_states.device
sms = torch.cuda.get_device_properties(device).multi_processor_count
max_workspace_size = min(max_workspace_size, sms)
workspace = torch.zeros(max_workspace_size,
dtype=torch.int,
device=device,
requires_grad=False)
scalar_type = get_scalar_type(num_bits, w_zeros is not None)
intermediate_cache = torch.empty(
(M * topk_ids.shape[1], N),
device=hidden_states.device,
dtype=hidden_states.dtype,
)
ops.moe_wna16_marlin_gemm(hidden_states,
intermediate_cache,
w,
scales,
w_zeros,
g_idx,
sort_indices,
workspace,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
topk_weights,
moe_block_size=block_size_m,
top_k=topk,
mul_topk_weights=False,
is_ep=expert_map is not None,
b_q_type=scalar_type,
size_m=M,
size_n=N,
size_k=K,
is_k_full=is_k_full,
use_atomic_add=False,
use_fp32_reduce=True,
is_zp_float=False)
intermediate_cache = intermediate_cache.view(-1, topk, N)
return torch.sum(intermediate_cache.view(*intermediate_cache.shape), dim=1)
def single_marlin_moe_fake(
hidden_states: torch.Tensor,
w: torch.Tensor,
scales: torch.Tensor,
gating_output: torch.Tensor,
topk: int,
renormalize: bool,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
g_idx: Optional[torch.Tensor] = None,
sort_indices: Optional[torch.Tensor] = None,
w_zeros: Optional[torch.Tensor] = None,
workspace: Optional[torch.Tensor] = None,
num_bits: int = 8,
is_k_full: bool = True,
) -> torch.Tensor:
return torch.empty_like(hidden_states)
direct_register_custom_op(
op_name="single_marlin_moe",
op_func=single_marlin_moe,
mutates_args=[],
fake_impl=single_marlin_moe_fake,
)
def fused_marlin_moe(hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
@ -172,6 +22,7 @@ def fused_marlin_moe(hidden_states: torch.Tensor,
gating_output: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
quant_type_id: int,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
g_idx1: Optional[torch.Tensor] = None,
@ -181,7 +32,6 @@ def fused_marlin_moe(hidden_states: torch.Tensor,
w1_zeros: Optional[torch.Tensor] = None,
w2_zeros: Optional[torch.Tensor] = None,
workspace: Optional[torch.Tensor] = None,
num_bits: int = 8,
is_k_full: bool = True,
inplace: bool = False) -> torch.Tensor:
"""
@ -211,6 +61,15 @@ def fused_marlin_moe(hidden_states: torch.Tensor,
Returns:
- torch.Tensor: The output tensor after applying the MoE layer.
"""
quant_type = ScalarType.from_id(quant_type_id)
assert quant_type in [
scalar_types.uint4, scalar_types.uint8b128, scalar_types.uint4b8,
scalar_types.float8_e4m3fn
]
int4_scalar_types = [scalar_types.uint4, scalar_types.uint4b8]
num_bits = 4 if quant_type in int4_scalar_types else 8
# Check constraints.
assert hidden_states.shape[0] == gating_output.shape[
0], "Number of tokens mismatch"
@ -248,18 +107,7 @@ def fused_marlin_moe(hidden_states: torch.Tensor,
expert_map)
if workspace is None:
max_workspace_size = (max(2 * N, K) // 64) * \
(sorted_token_ids.size(0) // block_size_m)
device = hidden_states.device
sms = torch.cuda.get_device_properties(device).multi_processor_count
max_workspace_size = min(max_workspace_size, sms * 4)
workspace = torch.zeros(max_workspace_size,
dtype=torch.int,
device=device,
requires_grad=False)
scalar_type1 = get_scalar_type(num_bits, w1_zeros is not None)
scalar_type2 = get_scalar_type(num_bits, w2_zeros is not None)
workspace = marlin_make_workspace_new(hidden_states.device, 4)
intermediate_cache2 = torch.empty(
(M * topk_ids.shape[1], N),
@ -276,6 +124,7 @@ def fused_marlin_moe(hidden_states: torch.Tensor,
intermediate_cache3 = intermediate_cache13[:M * topk_ids.shape[1] * K]
intermediate_cache3 = intermediate_cache3.view(-1, K)
maybe_warn_marlin_atomic_add(hidden_states.device, hidden_states.dtype)
use_atomic_add = hidden_states.dtype == torch.half or \
torch.cuda.get_device_capability(hidden_states.device)[0] >= 9
@ -296,7 +145,7 @@ def fused_marlin_moe(hidden_states: torch.Tensor,
top_k=topk,
mul_topk_weights=False,
is_ep=expert_map is not None,
b_q_type=scalar_type1,
b_q_type=quant_type,
size_m=M,
size_n=2 * N,
size_k=K,
@ -328,7 +177,7 @@ def fused_marlin_moe(hidden_states: torch.Tensor,
top_k=1,
mul_topk_weights=True,
is_ep=expert_map is not None,
b_q_type=scalar_type2,
b_q_type=quant_type,
size_m=M * topk,
size_n=K,
size_k=N,
@ -351,6 +200,7 @@ def fused_marlin_moe_fake(hidden_states: torch.Tensor,
gating_output: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
quant_type_id: int,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
g_idx1: Optional[torch.Tensor] = None,
@ -360,7 +210,6 @@ def fused_marlin_moe_fake(hidden_states: torch.Tensor,
w1_zeros: Optional[torch.Tensor] = None,
w2_zeros: Optional[torch.Tensor] = None,
workspace: Optional[torch.Tensor] = None,
num_bits: int = 8,
is_k_full: bool = True,
inplace: bool = False) -> torch.Tensor:
return torch.empty_like(hidden_states)

View File

@ -22,9 +22,10 @@ from vllm.model_executor.layers.quantization.utils import replace_parameter
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
apply_awq_marlin_linear, awq_to_marlin_zero_points, check_marlin_supported,
check_marlin_supports_layer, check_moe_marlin_supports_layer,
marlin_make_empty_g_idx, marlin_make_workspace, marlin_moe_permute_scales,
marlin_permute_scales, moe_awq_to_marlin_zero_points,
verify_marlin_supported, verify_marlin_supports_shape)
marlin_make_empty_g_idx, marlin_make_workspace_new,
marlin_moe_permute_scales, marlin_permute_scales,
moe_awq_to_marlin_zero_points, verify_marlin_supported,
verify_marlin_supports_shape)
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.parameter import (GroupQuantScaleParameter,
PackedvLLMParameter)
@ -267,8 +268,7 @@ class AWQMarlinLinearMethod(LinearMethodBase):
requires_grad=False)
# Allocate marlin workspace
layer.workspace = marlin_make_workspace(
layer.output_size_per_partition, device)
layer.workspace = marlin_make_workspace_new(device)
# Repack weights from AWQ format to marlin format.
marlin_qweight = ops.awq_marlin_repack(
@ -322,6 +322,9 @@ class AWQMoEMethod(FusedMoEMethodBase):
def __init__(self, quant_config: AWQMarlinConfig):
self.quant_config = quant_config
if self.quant_config.weight_bits != 4:
raise ValueError("AWQMoEMethod only supports 4bit now.")
self.quant_type = scalar_types.uint4
def create_weights(self, layer: torch.nn.Module, num_experts: int,
hidden_size: int, intermediate_size_per_partition: int,
@ -396,11 +399,7 @@ class AWQMoEMethod(FusedMoEMethodBase):
set_weight_attrs(w2_qzeros, extra_weight_attrs)
device = layer.w13_qweight.device
sms = torch.cuda.get_device_properties(device).multi_processor_count
layer.workspace = torch.zeros((sms * 4, ),
dtype=torch.int,
device=device,
requires_grad=False)
layer.workspace = marlin_make_workspace_new(device, 4)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
num_experts = layer.w13_qweight.shape[0]
@ -511,10 +510,9 @@ class AWQMoEMethod(FusedMoEMethodBase):
router_logits,
topk_weights,
topk_ids,
quant_type_id=self.quant_type.id,
global_num_experts=global_num_experts,
expert_map=expert_map,
w1_zeros=layer.w13_qzeros,
w2_zeros=layer.w2_qzeros,
workspace=layer.workspace,
num_bits=self.quant_config.weight_bits,
)
workspace=layer.workspace)

View File

@ -55,7 +55,7 @@ class CompressedTensorsW8A16Fp8(CompressedTensorsScheme):
# required by torch.compile to be torch.nn.Parameter
layer.input_scale = torch.nn.Parameter(layer.input_scale.data,
requires_grad=False)
prepare_fp8_layer_for_marlin(layer, strategy="channel")
prepare_fp8_layer_for_marlin(layer)
def create_weights(self, layer: torch.nn.Module, input_size: int,
output_partition_sizes: List[int],
@ -68,6 +68,7 @@ class CompressedTensorsW8A16Fp8(CompressedTensorsScheme):
layer.input_size_per_partition = input_size_per_partition
layer.output_size_per_partition = output_size_per_partition
layer.orig_dtype = params_dtype
layer.weight_block_size = None
# WEIGHT
weight = ModelWeightParameter(data=torch.empty(

View File

@ -21,19 +21,21 @@ from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin)
apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin,
prepare_moe_fp8_layer_for_marlin)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
is_layer_skipped)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
Fp8LinearOp, all_close_1d, convert_to_channelwise,
cutlass_block_fp8_supported, cutlass_fp8_supported,
maybe_create_device_identity, normalize_e4m3fn_to_e4m3fnuz,
per_tensor_dequantize, requantize_with_max_scale)
Fp8LinearOp, all_close_1d, cutlass_block_fp8_supported,
cutlass_fp8_supported, maybe_create_device_identity,
normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize,
requantize_with_max_scale)
from vllm.model_executor.parameter import (BlockQuantScaleParameter,
ModelWeightParameter,
PerTensorScaleParameter)
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types
ACTIVATION_SCHEMES = ["static", "dynamic"]
@ -181,10 +183,6 @@ class Fp8LinearMethod(LinearMethodBase):
self.use_marlin = False
self.block_quant = self.quant_config.weight_block_size is not None
if self.block_quant:
# Marlin doesn't support block-wise fp8
self.use_marlin = False
self.fp8_linear = Fp8LinearOp(
# Default to using per_token quantization if cutlass is supported
use_per_token_if_dynamic=cutlass_fp8_supported())
@ -203,10 +201,16 @@ class Fp8LinearMethod(LinearMethodBase):
output_size_per_partition = sum(output_partition_sizes)
weight_loader = extra_weight_attrs.get("weight_loader")
layer.logical_widths = output_partition_sizes
layer.input_size_per_partition = input_size_per_partition
layer.output_size_per_partition = output_size_per_partition
layer.orig_dtype = params_dtype
layer.weight_block_size = None
if self.block_quant:
tp_size = get_tensor_model_parallel_world_size()
assert self.quant_config.weight_block_size is not None
layer.weight_block_size = self.quant_config.weight_block_size
block_n, block_k = (
self.quant_config.weight_block_size[0],
self.quant_config.weight_block_size[1],
@ -229,12 +233,6 @@ class Fp8LinearMethod(LinearMethodBase):
f"{output_partition_size} is not divisible by "
f"weight quantization block_n = {block_n}.")
layer.logical_widths = output_partition_sizes
layer.input_size_per_partition = input_size_per_partition
layer.output_size_per_partition = output_size_per_partition
layer.orig_dtype = params_dtype
# WEIGHT
weight_dtype = (torch.float8_e4m3fn
if self.quant_config.is_checkpoint_fp8_serialized else
@ -303,9 +301,11 @@ class Fp8LinearMethod(LinearMethodBase):
return weight
def process_weights_after_loading(self, layer: Module) -> None:
size_k_first = True
# TODO(rob): refactor block quant into separate class.
if self.block_quant:
assert self.quant_config.activation_scheme == "dynamic"
size_k_first = False
if current_platform.is_fp8_fnuz():
weight, weight_scale_inv, _ = \
normalize_e4m3fn_to_e4m3fnuz(
@ -321,21 +321,12 @@ class Fp8LinearMethod(LinearMethodBase):
layer.weight = Parameter(weight, requires_grad=False)
layer.weight_scale_inv = Parameter(weight_scale_inv,
requires_grad=False)
return
# If checkpoint not serialized fp8, quantize the weights.
if not self.quant_config.is_checkpoint_fp8_serialized:
elif not self.quant_config.is_checkpoint_fp8_serialized:
qweight, weight_scale = ops.scaled_fp8_quant(layer.weight,
scale=None)
# If using marlin (w8a16), kernel uses channelwise weights,
# so extend the weight scales to be channelwise.
if self.use_marlin:
assert weight_scale.numel() == 1
weight_scale = convert_to_channelwise(
weight_scale.expand(len(layer.logical_widths)),
layer.logical_widths)
# Update the layer with the new values.
layer.weight = Parameter(qweight.t(), requires_grad=False)
layer.weight_scale = Parameter(weight_scale, requires_grad=False)
@ -349,20 +340,14 @@ class Fp8LinearMethod(LinearMethodBase):
if self.quant_config.activation_scheme == "static":
layer.input_scale = torch.nn.Parameter(layer.input_scale.data,
requires_grad=False)
# If using marlin (w8a16), kernel uses channelwise weights,
# so extend the weight scales to be channelwise.
if self.use_marlin:
weight = layer.weight
weight_scale = convert_to_channelwise(layer.weight_scale,
layer.logical_widths)
weight = layer.weight
weight_scale = layer.weight_scale
# If using w8a8, torch._scaled_mm needs per tensor, so
# requantize the logical shards as a single weight.
else:
if not self.use_marlin:
# Dequant -> Quant with max scale so we can run per tensor.
weight = layer.weight
weight_scale = layer.weight_scale
if current_platform.is_fp8_fnuz():
weight, weight_scale, input_scale = \
normalize_e4m3fn_to_e4m3fnuz(
@ -388,7 +373,7 @@ class Fp8LinearMethod(LinearMethodBase):
requires_grad=False)
if self.use_marlin:
prepare_fp8_layer_for_marlin(layer)
prepare_fp8_layer_for_marlin(layer, size_k_first)
# Activations not quantized for marlin.
del layer.input_scale
@ -444,6 +429,14 @@ class Fp8MoEMethod(FusedMoEMethodBase):
self.quant_config = quant_config
self.block_quant = self.quant_config.weight_block_size is not None
# For GPUs that lack FP8 hardware support, we can leverage the Marlin
# kernel for fast weight-only FP8 quantization
self.use_marlin = (not current_platform.has_device_capability(89)
or envs.VLLM_TEST_FORCE_FP8_MARLIN)
# Disable marlin for rocm
if current_platform.is_rocm():
self.use_marlin = False
# Check for DeepGemm support.
self.allow_deep_gemm = False
if envs.VLLM_USE_DEEP_GEMM:
@ -461,10 +454,17 @@ class Fp8MoEMethod(FusedMoEMethodBase):
intermediate_size_per_partition: int,
params_dtype: torch.dtype, **extra_weight_attrs):
layer.intermediate_size_per_partition = intermediate_size_per_partition
layer.hidden_size = hidden_size
layer.num_experts = num_experts
layer.orig_dtype = params_dtype
layer.weight_block_size = None
if self.quant_config.is_checkpoint_fp8_serialized:
params_dtype = torch.float8_e4m3fn
if self.block_quant:
assert self.quant_config.weight_block_size is not None
layer.weight_block_size = self.quant_config.weight_block_size
tp_size = get_tensor_model_parallel_world_size()
block_n, block_k = (
self.quant_config.weight_block_size[0],
@ -630,10 +630,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
layer.w2_weight_scale_inv = \
dg.get_col_major_tma_aligned_tensor(layer.w2_weight_scale_inv).contiguous()
return
# If checkpoint is fp16, quantize in place.
if not self.quant_config.is_checkpoint_fp8_serialized:
elif not self.quant_config.is_checkpoint_fp8_serialized:
fp8_dtype = current_platform.fp8_dtype()
w13_weight = torch.empty_like(layer.w13_weight.data,
dtype=fp8_dtype)
@ -677,8 +675,6 @@ class Fp8MoEMethod(FusedMoEMethodBase):
requires_grad=False)
layer.w2_weight = torch.nn.Parameter(shuffled_w2,
requires_grad=False)
return
# If checkpoint is fp8, we need to handle that the
# MoE kernels require single activation scale and single weight
# scale for w13 per expert.
@ -766,7 +762,12 @@ class Fp8MoEMethod(FusedMoEMethodBase):
layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales,
requires_grad=False)
return
if self.use_marlin:
prepare_moe_fp8_layer_for_marlin(layer, False)
# Activations not quantized for marlin.
del layer.w13_input_scale
del layer.w2_input_scale
def apply(
self,
@ -801,6 +802,20 @@ class Fp8MoEMethod(FusedMoEMethodBase):
e_score_correction_bias=e_score_correction_bias,
)
if self.use_marlin:
return torch.ops.vllm.fused_marlin_moe(
x,
layer.w13_weight,
layer.w2_weight,
layer.w13_weight_scale,
layer.w2_weight_scale,
router_logits,
topk_weights,
topk_ids,
quant_type_id=scalar_types.float8_e4m3fn.id,
global_num_experts=global_num_experts,
expert_map=expert_map)
return fused_experts(
x,
layer.w13_weight,

View File

@ -21,8 +21,8 @@ from vllm.model_executor.layers.quantization.utils.gptq_utils import (
get_linear_quant_method)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
check_marlin_supported, check_moe_marlin_supports_layer,
marlin_moe_permute_scales, marlin_repeat_scales_on_all_ranks,
verify_marlin_supported)
marlin_make_workspace_new, marlin_moe_permute_scales,
marlin_repeat_scales_on_all_ranks, verify_marlin_supported)
from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
GroupQuantScaleParameter,
PackedColumnParameter,
@ -350,6 +350,13 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
def __init__(self, quant_config: GPTQMarlinConfig) -> None:
self.quant_config = quant_config
if self.quant_config.quant_type.size_bits == 4:
self.quant_type = scalar_types.uint4b8
elif self.quant_config.quant_type.size_bits == 8:
self.quant_type = scalar_types.uint8b128
else:
raise ValueError(
"GPTQMarlinMoEMethod only supports int4 and int8 now.")
def create_weights(
self,
@ -498,11 +505,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
set_weight_attrs(w2_g_idx_sort_indices, extra_weight_attrs)
device = layer.w13_qweight.device
sms = torch.cuda.get_device_properties(device).multi_processor_count
layer.workspace = torch.zeros((sms * 4, ),
dtype=torch.int,
device=device,
requires_grad=False)
layer.workspace = marlin_make_workspace_new(device, 4)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
@ -633,12 +636,12 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
router_logits,
topk_weights,
topk_ids,
quant_type_id=self.quant_type.id,
global_num_experts=global_num_experts,
expert_map=expert_map,
g_idx1=layer.w13_g_idx,
g_idx2=layer.w2_g_idx,
sort_indices1=layer.w13_g_idx_sort_indices,
sort_indices2=layer.w2_g_idx_sort_indices,
num_bits=self.quant_config.quant_type.size_bits,
workspace=layer.workspace,
is_k_full=self.is_k_full)

View File

@ -8,7 +8,7 @@ from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
MARLIN_SUPPORTED_GROUP_SIZES, apply_gptq_marlin_linear,
check_marlin_supports_shape, marlin_is_k_full, marlin_make_empty_g_idx,
marlin_make_workspace, marlin_permute_scales, marlin_sort_g_idx,
marlin_make_workspace_new, marlin_permute_scales, marlin_sort_g_idx,
marlin_zero_points, query_marlin_supported_quant_types, unpack_cols)
from vllm.model_executor.parameter import (BasevLLMParameter,
permute_param_layout_)
@ -53,8 +53,7 @@ class MarlinLinearKernel(MPLinearKernel):
self.is_k_full = marlin_is_k_full(c.has_g_idx, row_parallel)
# Allocate marlin workspace.
self.workspace = marlin_make_workspace(c.partition_weight_shape[1],
device)
self.workspace = marlin_make_workspace_new(device)
# Default names since marlin requires empty parameters for these,
# TODO: remove this requirement from marlin (allow optional tensors)
@ -127,6 +126,5 @@ class MarlinLinearKernel(MPLinearKernel):
wtype=c.weight_type,
input_size_per_partition=c.partition_weight_shape[0],
output_size_per_partition=c.partition_weight_shape[1],
has_zp=self.config.zero_points,
is_k_full=self.is_k_full,
bias=bias)

View File

@ -7,12 +7,15 @@ import torch
import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import LinearBase
from vllm.platforms import current_platform
from vllm.scalar_type import ScalarType, scalar_types
from .quant_utils import pack_cols, unpack_cols
logger = init_logger(__name__)
GPTQ_MARLIN_TILE = 16
GPTQ_MARLIN_MIN_THREAD_N = 64
GPTQ_MARLIN_MIN_THREAD_K = 128
@ -29,9 +32,11 @@ USE_FP32_REDUCE_DEFAULT = True
# For binary size and compile time, we don't support the same types for with and
# without runtime zero-point. We support common cases, i.e. AWQ and GPTQ.
# TODO: we may want to move this into the C++ so its closer to the actual impl
def query_marlin_supported_quant_types(has_zp: bool,
device_capability: Optional[int] = None
):
def query_marlin_supported_quant_types(
has_zp: bool,
include_fp_type: bool = True,
device_capability: Optional[int] = None,
):
if device_capability is None:
capability_tuple = current_platform.get_device_capability()
device_capability = (-1 if capability_tuple is None else
@ -42,12 +47,13 @@ def query_marlin_supported_quant_types(has_zp: bool,
if has_zp:
# AWQ style, unsigned + runtime zero-point
return [scalar_types.uint4, scalar_types.uint8]
return [scalar_types.uint4]
else:
# GPTQ style, unsigned + symmetric bias
# TODO: once fp8_marlin is merged into "gptq_marlin" we should be able
# to add `scalar_types.float8_e4m3fn` here
return [scalar_types.uint4b8, scalar_types.uint8b128]
res = [scalar_types.uint4b8, scalar_types.uint8b128]
if include_fp_type:
res += [scalar_types.float8_e4m3fn]
return res
def _check_marlin_supported(
@ -62,7 +68,7 @@ def _check_marlin_supported(
capability_tuple.to_int())
supported_types = query_marlin_supported_quant_types(
has_zp, device_capability)
has_zp, True, device_capability)
if quant_type not in supported_types:
return (False, f"Marlin does not support weight_bits = {quant_type}. "
@ -175,6 +181,17 @@ def marlin_make_workspace(output_size_per_partition: int,
requires_grad=False)
def marlin_make_workspace_new(device: torch.device,
max_blocks_per_sm: int = 1) -> torch.Tensor:
# In the new marlin kernel, we use the num of threadblocks as workspace
# size. The num of threadblocks is is sms_count * max_blocks_per_sm.
sms = torch.cuda.get_device_properties(device).multi_processor_count
return torch.zeros(sms * max_blocks_per_sm,
dtype=torch.int,
device=device,
requires_grad=False)
def marlin_is_k_full(act_order: bool, is_row_parallel: bool) -> bool:
return (not act_order) or (act_order and not is_row_parallel)
@ -304,21 +321,50 @@ def moe_awq_to_marlin_zero_points(q_zp_packed: torch.Tensor, size_k: int,
return output
def maybe_warn_marlin_atomic_add(device, dtype):
if torch.compiler.is_dynamo_compiling():
return
device_capability = torch.cuda.get_device_capability(device)
if device_capability[0] < 9 and dtype == torch.bfloat16:
logger.info_once(
"You are running Marlin kernel with bf16 on GPUs before SM90. "
"You can consider change to fp16 to achieve better performance "
"if possible.")
def maybe_warn_marlin_atomic_add_env():
if torch.compiler.is_dynamo_compiling():
return
if envs.VLLM_MARLIN_USE_ATOMIC_ADD:
return
logger.info_once(
"Marlin kernel can achieve better performance for small size_n "
"with experimental use_atomic_add feature. "
"You can consider set environment variable "
"VLLM_MARLIN_USE_ATOMIC_ADD to 1 if possible.")
def should_use_atomic_add_reduce(m: int, n: int, k: int, device: torch.device,
dtype: torch.dtype) -> bool:
# the performance of atomicAdd is better than global reduce
# only when m*n is small and k is large
if n >= 2048 or k < 2048 or device.type != "cuda":
return False
# disable atomicAdd reduce by default,
# one can enable it with VLLM_MARLIN_USE_ATOMIC_ADD=1
if not envs.VLLM_MARLIN_USE_ATOMIC_ADD or device.type != "cuda":
if not envs.VLLM_MARLIN_USE_ATOMIC_ADD:
maybe_warn_marlin_atomic_add_env()
return False
# sm8x doesn't support atomicAdd + bfloat16 natively
device_capability = torch.cuda.get_device_capability(device)
if device_capability[0] < 9 and dtype == torch.bfloat16:
maybe_warn_marlin_atomic_add(device, dtype)
return False
# the performance of atomicAdd is better than global reduce
# only when m*n is small and k is large
return n < 2048 and k >= 2048
return True
def apply_gptq_marlin_linear(
@ -332,7 +378,6 @@ def apply_gptq_marlin_linear(
wtype: ScalarType,
output_size_per_partition: int,
input_size_per_partition: int,
has_zp: bool,
is_k_full: bool,
bias: Optional[torch.Tensor] = None,
use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT) -> torch.Tensor:
@ -346,6 +391,7 @@ def apply_gptq_marlin_linear(
dtype=input.dtype)
output = ops.gptq_marlin_gemm(reshaped_x,
None,
weight,
weight_scale,
weight_zp,
@ -358,7 +404,6 @@ def apply_gptq_marlin_linear(
size_k=input_size_per_partition,
is_k_full=is_k_full,
use_atomic_add=use_atomic_add,
has_zp=has_zp,
use_fp32_reduce=use_fp32_reduce,
is_zp_float=False)
@ -391,6 +436,7 @@ def apply_awq_marlin_linear(
dtype=input.dtype)
output = ops.gptq_marlin_gemm(reshaped_x,
None,
weight,
weight_scale,
weight_zp,
@ -401,8 +447,6 @@ def apply_awq_marlin_linear(
size_m=reshaped_x.shape[0],
size_n=output_size_per_partition,
size_k=input_size_per_partition,
is_k_full=True,
has_zp=True,
use_atomic_add=use_atomic_add,
use_fp32_reduce=use_fp32_reduce,
is_zp_float=False)

View File

@ -6,9 +6,11 @@ import torch
import vllm._custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
USE_FP32_REDUCE_DEFAULT, marlin_make_workspace_new, marlin_permute_scales,
should_use_atomic_add_reduce)
from vllm.platforms import current_platform
from .marlin_utils import marlin_make_workspace, marlin_permute_scales
from vllm.scalar_type import scalar_types
logger = init_logger(__name__)
@ -18,30 +20,40 @@ def is_fp8_marlin_supported():
def apply_fp8_marlin_linear(
input: torch.Tensor,
weight: torch.Tensor,
weight_scale: torch.Tensor,
workspace: torch.Tensor,
size_n: int,
size_k: int,
bias: Optional[torch.Tensor],
) -> torch.Tensor:
input: torch.Tensor,
weight: torch.Tensor,
weight_scale: torch.Tensor,
workspace: torch.Tensor,
size_n: int,
size_k: int,
bias: Optional[torch.Tensor],
use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT) -> torch.Tensor:
# For GPUs that lack FP8 hardware support, we can leverage the
# Marlin kernel for fast weight-only FP8 quantization
reshaped_x = input.reshape(-1, input.shape[-1])
out_shape = input.shape[:-1] + (size_n, )
output = ops.fp8_marlin_gemm(
a=reshaped_x,
b_q_weight=weight,
b_scales=weight_scale,
workspace=workspace,
num_bits=8,
size_m=reshaped_x.shape[0],
size_n=size_n,
size_k=size_k,
)
use_atomic_add = should_use_atomic_add_reduce(m=reshaped_x.size(0),
n=size_n,
k=size_k,
device=input.device,
dtype=input.dtype)
output = ops.gptq_marlin_gemm(a=reshaped_x,
c=None,
b_q_weight=weight,
b_scales=weight_scale,
b_zeros=None,
g_idx=None,
perm=None,
workspace=workspace,
b_q_type=scalar_types.float8_e4m3fn,
size_m=reshaped_x.size(0),
size_n=size_n,
size_k=size_k,
use_atomic_add=use_atomic_add,
use_fp32_reduce=use_fp32_reduce)
if bias is not None:
output.add_(bias) # In-place add
@ -50,7 +62,7 @@ def apply_fp8_marlin_linear(
def prepare_fp8_layer_for_marlin(layer: torch.nn.Module,
strategy: str = "tensor") -> None:
size_k_first: bool = True) -> None:
logger.warning_once(
"Your GPU does not have native support for FP8 computation but "
"FP8 quantization is being used. Weight-only FP8 compression will "
@ -60,51 +72,234 @@ def prepare_fp8_layer_for_marlin(layer: torch.nn.Module,
part_size_n = layer.output_size_per_partition
part_size_k = layer.input_size_per_partition
if size_k_first:
assert layer.weight.shape == (part_size_k, part_size_n)
else:
assert layer.weight.shape == (part_size_n, part_size_k)
device = layer.weight.device
# WORKSPACE
layer.workspace = marlin_make_workspace(part_size_n, device)
layer.workspace = marlin_make_workspace_new(device)
# WEIGHT
# Repack weights to marlin format
marlin_qweight = ops.gptq_marlin_repack(b_q_weight=pack_fp8_to_int32(
layer.weight),
perm=torch.empty(0,
dtype=torch.int,
device=device),
perm = torch.empty(0, dtype=torch.int, device=device)
qweight = pack_fp8_to_int32(layer.weight, size_k_first)
if not size_k_first:
qweight = qweight.T.contiguous()
marlin_qweight = ops.gptq_marlin_repack(b_q_weight=qweight,
perm=perm,
size_k=part_size_k,
size_n=part_size_n,
num_bits=8)
layer.weight = torch.nn.Parameter(marlin_qweight, requires_grad=False)
# WEIGHT SCALES
scales = layer.weight_scale.to(layer.orig_dtype)
# Permute scales
if "weight_scale" in dir(layer):
scales = layer.weight_scale.to(layer.orig_dtype)
elif "weight_scale_inv" in dir(layer):
scales = layer.weight_scale_inv.to(layer.orig_dtype)
del layer.weight_scale_inv
if layer.weight_block_size is None:
group_size = -1
else:
group_size = layer.weight_block_size[1]
# marlin kernel only support channel-wise and group-wise quantization
# we need to convert the scales
if layer.weight_block_size is None:
if scales.nelement() == 1:
# tensor-wise quantization -> channel-wise quantization
# (1, 1) =>(repeat)=> (1, size_n)
scales = scales.view(1, 1).repeat_interleave(part_size_n, 1)
elif scales.nelement() > 1 and scales.nelement() != part_size_n:
assert part_size_n % scales.nelement() == 0
s_size = scales.nelement()
# tensor-wise quantization (for gate-up proj)
# -> channel-wise quantization
# (1, s_size) =>(repeat)=> (1, size_n)
scales = scales.view(1, s_size)
scales = scales.repeat_interleave(part_size_n // s_size, 1)
else:
# channel-wise quantization
# (1, size_n)
scales = scales.view(1, part_size_n)
else:
# block-wise quantization -> group-wise quantization
# (size_k // block_size[1], ceil(size_n / block_size[0]))
# =>(repeat)=> (size_k // block_size[1], size_n)
block_n = layer.weight_block_size[0]
scales = scales.T.repeat_interleave(block_n, 1)
# size_n may not divisible by block_size[0]
scales = scales[:, :part_size_n]
marlin_scales = marlin_permute_scales(s=scales,
size_k=part_size_k,
size_n=part_size_n,
group_size=-1)
group_size=group_size)
layer.weight_scale = torch.nn.Parameter(marlin_scales, requires_grad=False)
def pack_fp8_to_int32(fp8_tensor: torch.Tensor) -> torch.Tensor:
def prepare_moe_fp8_layer_for_marlin(layer: torch.nn.Module,
size_k_first: bool = True) -> None:
logger.warning_once(
"Your GPU does not have native support for FP8 computation but "
"FP8 quantization is being used. Weight-only FP8 compression will "
"be used leveraging the Marlin kernel. This may degrade "
"performance for compute-heavy workloads.")
e = layer.num_experts
k = layer.hidden_size
n = layer.intermediate_size_per_partition
# WORKSPACE
device = layer.w13_weight.device
layer.workspace = marlin_make_workspace_new(device, 4)
perm = torch.empty(0, dtype=torch.int, device=device)
# WEIGHT
# Repack weights to marlin format
for name in ["w13_weight", "w2_weight"]:
weight = getattr(layer, name)
tensor_list = []
if "w13" in name:
size_n, size_k = n * 2, k
else:
size_n, size_k = k, n
if size_k_first:
assert weight.shape == (e, size_k, size_n)
else:
assert weight.shape == (e, size_n, size_k)
for i in range(e):
qweight = pack_fp8_to_int32(weight[i], size_k_first)
if not size_k_first:
qweight = qweight.T.contiguous()
marlin_qweight = ops.gptq_marlin_repack(b_q_weight=qweight,
perm=perm,
size_k=size_k,
size_n=size_n,
num_bits=8)
tensor_list.append(marlin_qweight)
weight = torch.cat([x.unsqueeze(0) for x in tensor_list], 0)
weight = torch.nn.Parameter(weight, requires_grad=False)
setattr(layer, name, weight)
# WEIGHT SCALES
# Permute scales
if layer.weight_block_size is None:
group_size = -1
else:
group_size = layer.weight_block_size[1]
for name in ["w13", "w2"]:
if name + "_weight_scale" in dir(layer):
new_name = name + "_weight_scale"
scales = getattr(layer, new_name).to(layer.orig_dtype)
delattr(layer, new_name)
elif name + "_weight_scale_inv" in dir(layer):
new_name = name + "_weight_scale_inv"
scales = getattr(layer, new_name).to(layer.orig_dtype)
delattr(layer, new_name)
tensor_list = []
if "w13" in name:
size_n, size_k = n * 2, k
else:
size_n, size_k = k, n
# marlin kernel only support channel-wise and group-wise quantization
# we need to convert the scales
if layer.weight_block_size is None:
if scales.nelement() == e:
# tensor-wise quantization -> channel-wise quantization
# (e, 1, 1) =>(repeat)=> (e, 1, size_n)
scales = scales.view(e, 1, 1).repeat_interleave(size_n, 2)
elif scales.nelement() > e and scales.nelement() != e * size_n:
assert (e * size_n) % scales.nelement() == 0
s_size = scales.nelement() // e
# tensor-wise quantization (for gate-up proj)
# -> channel-wise quantization
# (e, 1, s_size) =>(repeat)=> (e, 1, size_n)
scales = scales.view(e, 1, s_size)
scales = scales.repeat_interleave(size_n // s_size, 2)
else:
# channel-wise quantization
# (e, 1, size_n)
scales = scales.view(e, 1, size_n)
else:
# block-wise quantization -> group-wise quantization
# (e, size_k // block_size[1], ceil(size_n / block_size[0]))
# =>(repeat)=> (e, size_k // block_size[1], size_n)
block_n = layer.weight_block_size[0]
scales = scales.permute(0, 2, 1).repeat_interleave(block_n, 2)
# size_n may not divisible by block_size[0]
scales = scales[..., :size_n].contiguous()
for i in range(e):
marlin_scales = marlin_permute_scales(s=scales[i],
size_k=size_k,
size_n=size_n,
group_size=group_size)
tensor_list.append(marlin_scales)
scales = torch.cat([x.unsqueeze(0) for x in tensor_list], 0)
scales = torch.nn.Parameter(scales, requires_grad=False)
setattr(layer, name + "_weight_scale", scales)
def pack_fp8_to_int32(fp8_tensor: torch.Tensor,
size_k_first: bool = True) -> torch.Tensor:
"""
Repack FP8 weights to gptq format (packed int32 elements)
"""
assert fp8_tensor.dtype == torch.float8_e4m3fn
assert fp8_tensor.shape[0] % 4 == 0
assert fp8_tensor.ndim == 2
# Reshape to prepare for packing
reshaped = fp8_tensor.reshape(-1, 4, *fp8_tensor.shape[1:])
fp8_tensor = fp8_tensor.T if size_k_first else fp8_tensor
fp8_tensor = fp8_tensor.contiguous()
# fp8_tensor is contiguous and have shape (N, K) now
# with `.view(torch.int32)`, it become (N, K // 4)
int32_tensor = fp8_tensor.view(torch.int32)
return int32_tensor.T.contiguous() if size_k_first else int32_tensor
# Convert fp8 to uint8 (byte) representation
byte_tensor = reshaped.view(torch.uint8)
# Pack 4 uint8 values into one int32
packed = (byte_tensor[:, 0].to(torch.int32) |
(byte_tensor[:, 1].to(torch.int32) << 8) |
(byte_tensor[:, 2].to(torch.int32) << 16) |
(byte_tensor[:, 3].to(torch.int32) << 24))
def marlin_quant_fp8_torch(weight, group_size):
size_n, size_k = weight.shape
device = weight.device
return packed.view(fp8_tensor.shape[0] // 4,
*fp8_tensor.shape[1:]).contiguous()
if group_size != -1:
scales = weight.view(size_n, -1, group_size).abs().max(-1)[0] / 448
repeated_scales = scales.repeat_interleave(group_size, 1)
fp8_weight = (weight / repeated_scales).to(torch.float8_e4m3fn)
weight_ref = fp8_weight.to(weight.dtype) * repeated_scales
else:
scales = weight.view(size_n, 1, group_size).abs().max(-1)[0] / 448
repeated_scales = scales.repeat_interleave(size_k, 1)
fp8_weight = (weight / repeated_scales).to(torch.float8_e4m3fn)
weight_ref = fp8_weight.to(weight.dtype) * repeated_scales
packed_weight = pack_fp8_to_int32(fp8_weight, False).T.contiguous()
marlin_qweight = ops.gptq_marlin_repack(
b_q_weight=packed_weight,
perm=torch.empty(0, dtype=torch.int, device=device),
size_k=size_k,
size_n=size_n,
num_bits=8,
)
marlin_scales = marlin_permute_scales(s=scales.T,
size_k=size_k,
size_n=size_n,
group_size=group_size)
return weight_ref.T, marlin_qweight, marlin_scales

View File

@ -6,6 +6,8 @@ from dataclasses import dataclass
from enum import Enum
from typing import Optional, Union
_SCALAR_TYPES_ID_MAP = {}
# Mirrors enum in `core/scalar_type.hpp`
class NanRepr(Enum):
@ -158,6 +160,8 @@ class ScalarType:
assert offset <= 64, \
f"ScalarType fields too big {offset} to fit into an int64"
_SCALAR_TYPES_ID_MAP[val] = self
return val
@property
@ -295,6 +299,13 @@ class ScalarType:
ret.id # noqa B018: make sure the id is cached
return ret
@classmethod
def from_id(cls, scalar_type_id: int):
if scalar_type_id not in _SCALAR_TYPES_ID_MAP:
raise ValueError(
f"scalar_type_id {scalar_type_id} doesn't exists.")
return _SCALAR_TYPES_ID_MAP[scalar_type_id]
# naming generally follows: https://github.com/jax-ml/ml_dtypes
# for floating point types (leading f) the scheme is: