[Kernel] some optimizations for dense marlin and moe marlin (#16850)
Signed-off-by: Jinzhen Lin <linjinzhen@hotmail.com>
@@ -301,8 +301,52 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
# are not supported by Machete yet.
cuda_archs_loose_intersection(MARLIN_ARCHS "8.0;8.6;8.7;8.9;9.0;10.0;10.1;12.0" "${CUDA_ARCHS}")
if (MARLIN_ARCHS)

#
# For the Marlin kernels we automatically generate sources for various
# preselected input type pairs and schedules.
# Generate sources:
set(MARLIN_GEN_SCRIPT
${CMAKE_CURRENT_SOURCE_DIR}/csrc/quantization/gptq_marlin/generate_kernels.py)
file(MD5 ${MARLIN_GEN_SCRIPT} MARLIN_GEN_SCRIPT_HASH)

message(STATUS "Marlin generation script hash: ${MARLIN_GEN_SCRIPT_HASH}")
message(STATUS "Last run Marlin generate script hash: $CACHE{MARLIN_GEN_SCRIPT_HASH}")

if (NOT DEFINED CACHE{MARLIN_GEN_SCRIPT_HASH}
OR NOT $CACHE{MARLIN_GEN_SCRIPT_HASH} STREQUAL ${MARLIN_GEN_SCRIPT_HASH})
execute_process(
COMMAND ${CMAKE_COMMAND} -E env
PYTHONPATH=$PYTHONPATH
${Python_EXECUTABLE} ${MARLIN_GEN_SCRIPT}
RESULT_VARIABLE marlin_generation_result
OUTPUT_VARIABLE marlin_generation_output
OUTPUT_FILE ${CMAKE_CURRENT_BINARY_DIR}/marlin_generation.log
ERROR_FILE ${CMAKE_CURRENT_BINARY_DIR}/marlin_generation.log
)

if (NOT marlin_generation_result EQUAL 0)
message(FATAL_ERROR "Marlin generation failed."
" Result: \"${marlin_generation_result}\""
"\nCheck the log for details: "
"${CMAKE_CURRENT_BINARY_DIR}/marlin_generation.log")
else()
set(MARLIN_GEN_SCRIPT_HASH ${MARLIN_GEN_SCRIPT_HASH}
CACHE STRING "Last run Marlin generate script hash" FORCE)
message(STATUS "Marlin generation completed successfully.")
endif()
else()
message(STATUS "Marlin generation script has not changed, skipping generation.")
endif()

file(GLOB MARLIN_TEMPLATE_KERNEL_SRC "csrc/quantization/gptq_marlin/kernel_*.cu")
set_gencode_flags_for_srcs(
SRCS "${MARLIN_TEMPLATE_KERNEL_SRC}"
CUDA_ARCHS "${MARLIN_ARCHS}")

list(APPEND VLLM_EXT_SRC ${MARLIN_TEMPLATE_KERNEL_SRC})

set(MARLIN_SRCS
"csrc/quantization/fp8/fp8_marlin.cu"
"csrc/quantization/marlin/dense/marlin_cuda_kernel.cu"
"csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu"
"csrc/quantization/marlin/qqq/marlin_qqq_gemm_kernel.cu"
@@ -644,7 +688,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
OR NOT $CACHE{MOE_MARLIN_GEN_SCRIPT_HASH} STREQUAL ${MOE_MARLIN_GEN_SCRIPT_HASH})
execute_process(
COMMAND ${CMAKE_COMMAND} -E env
PYTHONPATH=${CMAKE_CURRENT_SOURCE_DIR}/csrc/cutlass_extensions/:${CUTLASS_DIR}/python/:${VLLM_PYTHON_PATH}:$PYTHONPATH
PYTHONPATH=$PYTHONPATH
${Python_EXECUTABLE} ${MOE_MARLIN_GEN_SCRIPT}
RESULT_VARIABLE moe_marlin_generation_result
OUTPUT_VARIABLE moe_marlin_generation_output
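The two hunks above gate Marlin kernel generation on an MD5 hash of the generator script: CMake hashes generate_kernels.py, compares it with the hash cached by the previous configure, and reruns the generator (updating the cache entry) only when the script has changed. Below is a minimal Python sketch of the same cache-by-hash pattern; the paths and the JSON stamp file are illustrative, not part of the build.

```python
import hashlib
import json
import pathlib
import subprocess
import sys

# Illustrative paths; the real build uses CMake's cache, not a JSON stamp file.
GEN_SCRIPT = pathlib.Path("csrc/quantization/gptq_marlin/generate_kernels.py")
STAMP_FILE = pathlib.Path("build/marlin_gen_stamp.json")


def regenerate_if_changed() -> None:
    """Re-run the generator only when the generator script itself changed."""
    current = hashlib.md5(GEN_SCRIPT.read_bytes()).hexdigest()
    previous = None
    if STAMP_FILE.exists():
        previous = json.loads(STAMP_FILE.read_text()).get("hash")
    if previous == current:
        print("generator unchanged, skipping")
        return
    result = subprocess.run([sys.executable, str(GEN_SCRIPT)])
    if result.returncode != 0:
        raise RuntimeError(f"generation failed with code {result.returncode}")
    STAMP_FILE.parent.mkdir(parents=True, exist_ok=True)
    STAMP_FILE.write_text(json.dumps({"hash": current}))


if __name__ == "__main__":
    regenerate_if_changed()
```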
csrc/moe/marlin_moe_wna16/.gitignore (vendored, new file, +1)
@@ -0,0 +1 @@
kernel_*.cu
@@ -25,15 +25,13 @@ TEMPLATE = ("template __global__ void Marlin<"
"{{thread_k_blocks}}, "
"{{'true' if m_block_size_8 else 'false'}}, "
"{{stages}}, "
"{{'true' if has_act_order else 'false'}}, "
"{{'true' if has_zp else 'false'}}, "
"{{group_blocks}}, "
"{{'true' if is_zp_float else 'false'}}>"
"( MARLIN_KERNEL_PARAMS );")

# int8 with zero point case (vllm::kU8) is also supported,
# we don't add it to reduce wheel size.
SCALAR_TYPES = ["vllm::kU4", "vllm::kU4B8", "vllm::kU8B128"]
SCALAR_TYPES = ["vllm::kU4", "vllm::kU4B8", "vllm::kU8B128", "vllm::kFE4M3fn"]
THREAD_CONFIGS = [(128, 128, 256), (64, 256, 256), (64, 128, 128)]

THREAD_M_BLOCKS = [0.5, 1, 2, 3, 4]
@@ -52,21 +50,29 @@ def remove_old_kernels():

def generate_new_kernels():
for scalar_type, dtype in itertools.product(SCALAR_TYPES, DTYPES):
has_zp = "B" not in scalar_type
all_template_str_list = []

for group_blocks, m_blocks, thread_configs in itertools.product(
GROUP_BLOCKS, THREAD_M_BLOCKS, THREAD_CONFIGS):

has_act_order = group_blocks == 0
if has_zp and has_act_order:
# act order case only support gptq-int4 and gptq-int8
if group_blocks == 0 and scalar_type not in [
"vllm::kU4B8", "vllm::kU8B128"
]:
continue
if thread_configs[2] == 256:
# for small batch (m_blocks == 1), we only need (128, 128, 256)
# for large batch (m_blocks > 1), we only need (64, 256, 256)
if m_blocks <= 1 and thread_configs[0] != 128:
continue
if m_blocks > 1 and thread_configs[0] != 64:
continue

# we only support channelwise quantization and group_size == 128
# for fp8
if scalar_type == "vllm::kFE4M3fn" and group_blocks not in [-1, 8]:
continue

k_blocks = thread_configs[0] // 16
n_blocks = thread_configs[1] // 16
threads = thread_configs[2]
@@ -82,8 +88,6 @@ def generate_new_kernels():
thread_k_blocks=k_blocks,
m_block_size_8=m_blocks == 0.5,
stages="pipe_stages",
has_act_order=has_act_order,
has_zp=has_zp,
group_blocks=group_blocks,
is_zp_float=False,
)
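The filters above keep the number of generated kernel_*.cu files, and therefore the wheel size, down: act-order kernels are emitted only for the GPTQ int4/int8 types, the 256-thread configs are split by batch size, and fp8 weights are limited to channelwise or group-size-128 quantization. The standalone sketch below restates those pruning rules; GROUP_BLOCKS and DTYPES are assumed values rather than copies from the script, so the printed count is only illustrative.

```python
import itertools

SCALAR_TYPES = ["vllm::kU4", "vllm::kU4B8", "vllm::kU8B128", "vllm::kFE4M3fn"]
THREAD_CONFIGS = [(128, 128, 256), (64, 256, 256), (64, 128, 128)]
THREAD_M_BLOCKS = [0.5, 1, 2, 3, 4]
GROUP_BLOCKS = [0, -1, 2, 4, 8]   # assumed; 0 means act_order
DTYPES = ["fp16", "bf16"]         # assumed


def kept_configs():
    for scalar_type, dtype in itertools.product(SCALAR_TYPES, DTYPES):
        for group_blocks, m_blocks, (tk, tn, threads) in itertools.product(
                GROUP_BLOCKS, THREAD_M_BLOCKS, THREAD_CONFIGS):
            # act-order kernels only exist for the GPTQ types
            if group_blocks == 0 and scalar_type not in ("vllm::kU4B8",
                                                         "vllm::kU8B128"):
                continue
            # 256-thread configs: small batches use (128, 128, 256),
            # larger batches use (64, 256, 256)
            if threads == 256:
                if m_blocks <= 1 and tk != 128:
                    continue
                if m_blocks > 1 and tk != 64:
                    continue
            # fp8 weights: only channelwise (-1) or group_size == 128 (8 blocks)
            if scalar_type == "vllm::kFE4M3fn" and group_blocks not in (-1, 8):
                continue
            yield scalar_type, dtype, group_blocks, m_blocks, (tk, tn, threads)


print(sum(1 for _ in kept_configs()), "kernel instantiations survive the filters")
```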
@@ -18,7 +18,7 @@
const float *__restrict__ topk_weights_ptr, int top_k, \
bool mul_topk_weights, bool is_ep, int num_groups, int prob_m, \
int prob_n, int prob_k, int *locks, bool use_atomic_add, \
bool use_fp32_reduce
bool use_fp32_reduce, int max_shared_mem

namespace MARLIN_NAMESPACE_NAME {
template <typename scalar_t, // compute dtype, half or nv_float16
@@ -33,11 +33,9 @@ template <typename scalar_t, // compute dtype, half or nv_float16
// only works when thread_m_blocks == 1
const int stages, // number of stages for the async global->shared
// fetch pipeline
const bool has_act_order, // whether act_order is enabled
const bool has_zp, // whether zero-points are enabled
const int group_blocks, // number of consecutive 16x16 blocks
// with a separate quantization scale
const bool is_zp_float // is zero point of float16 type?
const int group_blocks, // number of consecutive 16x16 blocks
// with a separate quantization scale
const bool is_zp_float // is zero point of float16 type?
>
__global__ void Marlin(MARLIN_KERNEL_PARAMS);

@@ -25,6 +25,7 @@

#include "quantization/gptq_marlin/marlin.cuh"
#include "quantization/gptq_marlin/marlin_dtypes.cuh"
#include "quantization/gptq_marlin/dequant.h"
#include "core/scalar_type.hpp"

#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \
@@ -48,11 +49,9 @@ template <typename scalar_t, // compute dtype, half or nv_float16
// only works when thread_m_blocks == 1
const int stages, // number of stages for the async global->shared
// fetch pipeline
const bool has_act_order, // whether act_order is enabled
const bool has_zp, // whether zero-points are enabled
const int group_blocks, // number of consecutive 16x16 blocks
// with a separate quantization scale
const bool is_zp_float // is zero point of float16 type?
const int group_blocks, // number of consecutive 16x16 blocks
// with a separate quantization scale
const bool is_zp_float // is zero point of float16 type?
>
__global__ void Marlin(
const int4* __restrict__ A, // fp16 input matrix of shape mxk
@@ -77,8 +76,8 @@ __global__ void Marlin(
int prob_k, // reduction dimension k
int* locks, // extra global storage for barrier synchronization
bool use_atomic_add, // whether to use atomic add to reduce
bool use_fp32_reduce // whether to use fp32 global reduce
) {}
bool use_fp32_reduce, // whether to use fp32 global reduce
int max_shared_mem) {}

} // namespace MARLIN_NAMESPACE_NAME

@@ -166,144 +165,6 @@ __device__ inline void ldsm(typename ScalarType<scalar_t>::FragA& frag_a,
|
||||
}
|
||||
}
|
||||
|
||||
// Lookup-table based 3-input logical operation; explicitly used for
|
||||
// dequantization as the compiler does not seem to automatically recognize it in
|
||||
// all cases.
|
||||
template <int lut>
|
||||
__device__ inline int lop3(int a, int b, int c) {
|
||||
int res;
|
||||
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
|
||||
: "=r"(res)
|
||||
: "r"(a), "r"(b), "r"(c), "n"(lut));
|
||||
return res;
|
||||
}
|
||||
|
||||
// Constructs destination register by taking bytes from 2 sources (based on
|
||||
// mask)
|
||||
template <int start_byte, int mask>
|
||||
__device__ inline uint32_t prmt(uint32_t a) {
|
||||
uint32_t res;
|
||||
asm volatile("prmt.b32 %0, %1, %2, %3;\n"
|
||||
: "=r"(res)
|
||||
: "r"(a), "n"(start_byte), "n"(mask));
|
||||
return res;
|
||||
}
|
||||
|
||||
template <typename scalar_t, int bit>
|
||||
__device__ inline typename ScalarType<scalar_t>::FragB dequant(
|
||||
int q, typename ScalarType<scalar_t>::FragB& frag_b);
|
||||
|
||||
//
|
||||
// Efficiently dequantize 4bit values packed in an int32 value into a full
|
||||
// B-fragment of 4 fp16 values. We mostly follow the strategy in the link below,
|
||||
// with some small changes:
|
||||
// - FP16:
|
||||
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L215-L287
|
||||
// - BF16:
|
||||
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L327-L385
|
||||
//
|
||||
template <>
|
||||
__device__ inline typename ScalarType<half>::FragB dequant<half, 4>(
|
||||
int q, typename ScalarType<half>::FragB& frag_b) {
|
||||
const int LO = 0x000f000f;
|
||||
const int HI = 0x00f000f0;
|
||||
const int EX = 0x64006400;
|
||||
// Guarantee that the `(a & b) | c` operations are LOP3s.
|
||||
int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX);
|
||||
int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX);
|
||||
// We want signed int4 outputs, hence we fuse the `-8` symmetric zero point
|
||||
// directly into `SUB` and `ADD`.
|
||||
const int SUB = 0x64086408;
|
||||
const int MUL = 0x2c002c00;
|
||||
const int ADD = 0xd480d480;
|
||||
frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo),
|
||||
*reinterpret_cast<const half2*>(&SUB));
|
||||
frag_b[1] = __hfma2(*reinterpret_cast<half2*>(&hi),
|
||||
*reinterpret_cast<const half2*>(&MUL),
|
||||
*reinterpret_cast<const half2*>(&ADD));
|
||||
return frag_b;
|
||||
}
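The magic constants in dequant<half, 4> above can be checked directly: OR-ing a nibble into the mantissa of 0x6400 produces the half value 1024 + x, so subtracting 0x6408 (= 1032.0) yields x - 8, and for the high nibble a fused multiply-add by 0x2c00 (= 1/16) with addend 0xd480 (= -72) lands on the same result. A small NumPy check of those constants (not part of the kernel) is shown below.

```python
import numpy as np


def as_half(bits: int) -> float:
    """Reinterpret a 16-bit pattern as an IEEE fp16 value."""
    return float(np.array([bits], dtype=np.uint16).view(np.float16)[0])


assert as_half(0x6400) == 1024.0   # EX: exponent-only bias pattern
assert as_half(0x6408) == 1032.0   # SUB: 1024 + 8
assert as_half(0x2C00) == 0.0625   # MUL: 1/16
assert as_half(0xD480) == -72.0    # ADD

for x in range(16):
    lo = as_half(0x6400 | x)            # low nibble ORed into the mantissa
    hi = as_half(0x6400 | (x << 4))     # high nibble ORed into the mantissa
    assert lo - 1032.0 == x - 8         # the __hsub2 path
    assert hi * 0.0625 - 72.0 == x - 8  # the __hfma2 path

print("int4 -> fp16 dequant constants check out")
```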
|
||||
|
||||
template <>
|
||||
__device__ inline typename ScalarType<nv_bfloat16>::FragB
|
||||
dequant<nv_bfloat16, 4>(int q,
|
||||
typename ScalarType<nv_bfloat16>::FragB& frag_b) {
|
||||
static constexpr uint32_t MASK = 0x000f000f;
|
||||
static constexpr uint32_t EX = 0x43004300;
|
||||
|
||||
// Guarantee that the `(a & b) | c` operations are LOP3s.
|
||||
|
||||
int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
|
||||
q >>= 4;
|
||||
int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
|
||||
|
||||
static constexpr uint32_t MUL = 0x3F803F80;
|
||||
static constexpr uint32_t ADD = 0xC308C308;
|
||||
|
||||
frag_b[0] = __hfma2(*reinterpret_cast<nv_bfloat162*>(&lo),
|
||||
*reinterpret_cast<const nv_bfloat162*>(&MUL),
|
||||
*reinterpret_cast<const nv_bfloat162*>(&ADD));
|
||||
frag_b[1] = __hfma2(*reinterpret_cast<nv_bfloat162*>(&hi),
|
||||
*reinterpret_cast<const nv_bfloat162*>(&MUL),
|
||||
*reinterpret_cast<const nv_bfloat162*>(&ADD));
|
||||
return frag_b;
|
||||
}
|
||||
|
||||
//
|
||||
// Fast Int8ToFp16/Int8ToBf16: Efficiently dequantize 8bit int values to fp16 or
|
||||
// bf16 Reference:
|
||||
// - FP16:
|
||||
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L53-L85
|
||||
// - BF16:
|
||||
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L125-L175
|
||||
//
|
||||
template <>
|
||||
__device__ inline typename ScalarType<half>::FragB dequant<half, 8>(
|
||||
int q, typename ScalarType<half>::FragB& frag_b) {
|
||||
static constexpr uint32_t mask_for_elt_01 = 0x5250;
|
||||
static constexpr uint32_t mask_for_elt_23 = 0x5351;
|
||||
static constexpr uint32_t start_byte_for_fp16 = 0x64646464;
|
||||
|
||||
uint32_t lo = prmt<start_byte_for_fp16, mask_for_elt_01>(q);
|
||||
uint32_t hi = prmt<start_byte_for_fp16, mask_for_elt_23>(q);
|
||||
|
||||
static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480;
|
||||
|
||||
frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo),
|
||||
*reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
|
||||
frag_b[1] = __hsub2(*reinterpret_cast<half2*>(&hi),
|
||||
*reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
|
||||
return frag_b;
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline typename ScalarType<nv_bfloat16>::FragB
|
||||
dequant<nv_bfloat16, 8>(int q,
|
||||
typename ScalarType<nv_bfloat16>::FragB& frag_b) {
|
||||
float fp32_intermediates[4];
|
||||
uint32_t* fp32_intermediates_casted =
|
||||
reinterpret_cast<uint32_t*>(fp32_intermediates);
|
||||
|
||||
static constexpr uint32_t fp32_base = 0x4B000000;
|
||||
fp32_intermediates_casted[0] = __byte_perm(q, fp32_base, 0x7650);
|
||||
fp32_intermediates_casted[1] = __byte_perm(q, fp32_base, 0x7652);
|
||||
fp32_intermediates_casted[2] = __byte_perm(q, fp32_base, 0x7651);
|
||||
fp32_intermediates_casted[3] = __byte_perm(q, fp32_base, 0x7653);
|
||||
|
||||
fp32_intermediates[0] -= 8388736.f;
|
||||
fp32_intermediates[1] -= 8388736.f;
|
||||
fp32_intermediates[2] -= 8388736.f;
|
||||
fp32_intermediates[3] -= 8388736.f;
|
||||
|
||||
uint32_t* bf16_result_ptr = reinterpret_cast<uint32_t*>(&frag_b);
|
||||
bf16_result_ptr[0] = __byte_perm(fp32_intermediates_casted[0],
|
||||
fp32_intermediates_casted[1], 0x7632);
|
||||
bf16_result_ptr[1] = __byte_perm(fp32_intermediates_casted[2],
|
||||
fp32_intermediates_casted[3], 0x7632);
|
||||
|
||||
return frag_b;
|
||||
}
|
||||
|
||||
// Multiply dequantized values by the corresponding quantization scale; used
|
||||
// only for grouped quantization.
|
||||
template <typename scalar_t>
|
||||
@@ -429,11 +290,9 @@ template <typename scalar_t, // compute dtype, half or nv_float16
|
||||
// only works when thread_m_blocks == 1
|
||||
const int stages, // number of stages for the async global->shared
|
||||
// fetch pipeline
|
||||
const bool has_act_order, // whether act_order is enabled
|
||||
const bool has_zp, // whether zero-points are enabled
|
||||
const int group_blocks, // number of consecutive 16x16 blocks
|
||||
// with a separate quantization scale
|
||||
const bool is_zp_float // is zero point of float16 type?
|
||||
const int group_blocks, // number of consecutive 16x16 blocks
|
||||
// with a separate quantization scale
|
||||
const bool is_zp_float // is zero point of float16 type?
|
||||
>
|
||||
__global__ void Marlin(
|
||||
const int4* __restrict__ A, // fp16 input matrix of shape mxk
|
||||
@@ -458,8 +317,8 @@ __global__ void Marlin(
|
||||
int prob_k, // reduction dimension k
|
||||
int* locks, // extra global storage for barrier synchronization
|
||||
bool use_atomic_add, // whether to use atomic add to reduce
|
||||
bool use_fp32_reduce // whether to use fp32 global reduce
|
||||
) {
|
||||
bool use_fp32_reduce, // whether to use fp32 global reduce
|
||||
int max_shared_mem) {
|
||||
// Each threadblock processes one "stripe" of the B matrix with (roughly) the
|
||||
// same size, which might involve multiple column "slices" (of width 16 *
|
||||
// `thread_n_blocks`). Stripes are defined as shown in the 3x3 matrix 5 SM
|
||||
@@ -481,6 +340,8 @@ __global__ void Marlin(
|
||||
|
||||
extern __shared__ int4 sh[];
|
||||
static constexpr auto w_type = vllm::ScalarType::from_id(w_type_id);
|
||||
constexpr bool has_zp = w_type == vllm::kU4 || w_type == vllm::kU8;
|
||||
constexpr bool has_act_order = group_blocks == 0;
|
||||
|
||||
constexpr int pack_factor = 32 / w_type.size_bits();
|
||||
static_assert(thread_m_blocks == 1 || !m_block_size_8);
|
||||
@@ -534,13 +395,20 @@ __global__ void Marlin(
|
||||
int64_t B_expert_off = 0;
|
||||
|
||||
int4* sh_block_sorted_ids_int4 = sh;
|
||||
int4* sh_rd_block_sorted_ids_int4 =
|
||||
sh_block_sorted_ids_int4 + moe_block_size / 4;
|
||||
int4* sh_block_topk_weights_int4 =
|
||||
sh_rd_block_sorted_ids_int4 + moe_block_size / 4;
|
||||
// sh_block_topk_weights_int4 only needs (moe_block_size / 4);
|
||||
// but we pad to align to 256 bytes
|
||||
int4* sh_new =
|
||||
sh_block_topk_weights_int4 + moe_block_size / 2 + moe_block_size;
|
||||
int32_t* sh_block_sorted_ids =
|
||||
reinterpret_cast<int*>(sh_block_sorted_ids_int4);
|
||||
int4* sh_block_topk_weights_int4 =
|
||||
sh_block_sorted_ids_int4 + moe_block_size / 4;
|
||||
int32_t* sh_rd_block_sorted_ids =
|
||||
reinterpret_cast<int*>(sh_rd_block_sorted_ids_int4);
|
||||
scalar_t2* sh_block_topk_weights =
|
||||
reinterpret_cast<scalar_t2*>(sh_block_topk_weights_int4);
|
||||
int4* sh_new = sh_block_topk_weights_int4 + moe_block_size / 4;
|
||||
|
||||
int32_t block_num_valid_tokens = 0;
|
||||
int32_t locks_off = 0;
|
||||
@@ -584,6 +452,11 @@ __global__ void Marlin(
|
||||
sh_block_sorted_ids_int4[tid4] = reinterpret_cast<const int4*>(
|
||||
sorted_token_ids_ptr)[block_id * moe_block_size / 4 + tid4];
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; i++)
|
||||
sh_rd_block_sorted_ids[tid4 * 4 + i] =
|
||||
sh_block_sorted_ids[tid4 * 4 + i] / top_k;
|
||||
|
||||
if (mul_topk_weights) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; i++) {
|
||||
@@ -743,6 +616,7 @@ __global__ void Marlin(
|
||||
constexpr int g_idx_stage = has_act_order ? (tb_k * sizeof(int)) / 16 : 0;
|
||||
// constexpr int act_s_row_stride = 1;
|
||||
// int act_s_col_stride = act_s_row_stride * num_groups;
|
||||
constexpr int act_s_max_num_groups = 32;
|
||||
int act_s_col_stride = 1;
|
||||
int act_s_col_warp_stride = act_s_col_stride * 8;
|
||||
int tb_n_warps = thread_n_blocks / 4;
|
||||
@@ -758,9 +632,9 @@ __global__ void Marlin(
|
||||
int zp_gl_rd_delta = zp_gl_stride;
|
||||
|
||||
// Global A read index of current thread.
|
||||
int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) +
|
||||
(threadIdx.x % a_gl_rd_delta_o);
|
||||
a_gl_rd += a_gl_rd_delta_o * slice_row;
|
||||
int a_gl_rd_row = threadIdx.x / a_gl_rd_delta_o;
|
||||
int a_gl_rd_col = a_gl_rd_delta_o * slice_row + threadIdx.x % a_gl_rd_delta_o;
|
||||
|
||||
// Shared write index of current thread.
|
||||
int a_sh_wr = a_sh_stride * (threadIdx.x / a_gl_rd_delta_o) +
|
||||
(threadIdx.x % a_gl_rd_delta_o);
|
||||
@@ -774,8 +648,8 @@ __global__ void Marlin(
|
||||
(threadIdx.x % b_sh_stride_threads) * b_thread_vecs;
|
||||
b_gl_rd += b_sh_stride * slice_col;
|
||||
b_gl_rd += b_gl_rd_delta_o * slice_row;
|
||||
int b_sh_wr = threadIdx.x * b_thread_vecs;
|
||||
int b_sh_rd = threadIdx.x * b_thread_vecs;
|
||||
auto b_sh_wr = threadIdx.x * b_thread_vecs;
|
||||
auto b_sh_rd = threadIdx.x * b_thread_vecs;
|
||||
|
||||
// For act_order
|
||||
constexpr int k_iter_size = tb_k / b_sh_wr_iters;
|
||||
@@ -794,7 +668,7 @@ __global__ void Marlin(
|
||||
s_sh_stride * slice_col + threadIdx.x;
|
||||
}
|
||||
}
|
||||
int s_sh_wr = threadIdx.x;
|
||||
auto s_sh_wr = threadIdx.x;
|
||||
bool s_sh_wr_pred = threadIdx.x < s_sh_stride;
|
||||
|
||||
// Zero-points
|
||||
@@ -807,7 +681,7 @@ __global__ void Marlin(
|
||||
zp_sh_stride * slice_col + threadIdx.x;
|
||||
}
|
||||
}
|
||||
int zp_sh_wr = threadIdx.x;
|
||||
auto zp_sh_wr = threadIdx.x;
|
||||
bool zp_sh_wr_pred = threadIdx.x < zp_sh_stride;
|
||||
|
||||
// We use a different scale layout for grouped and column-wise quantization as
|
||||
@@ -851,7 +725,7 @@ __global__ void Marlin(
|
||||
// each warp must also write a consecutive memory segment?
|
||||
auto transform_a = [&](int i) {
|
||||
int row = i / a_gl_rd_delta_o;
|
||||
return a_gl_rd_delta_o * row + (i % a_gl_rd_delta_o) ^ row;
|
||||
return a_gl_rd_delta_o * row + (i % a_gl_rd_delta_o) ^ (row % 8);
|
||||
};
|
||||
// Since the computation of this remapping is non-trivial and, due to our main
|
||||
// loop unrolls, all shared memory accesses are static, we simply precompute
|
||||
@@ -879,12 +753,28 @@ __global__ void Marlin(
B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd;

// Shared memory storage for global fetch pipelines.
int4* sh_a = sh_new;
int4* sh_b = sh_a + (stages * a_sh_stage);
int4* sh_g_idx = sh_b + (stages * b_sh_stage);
constexpr int sh_red_size = (2 * thread_n_blocks + 1) * 16 * thread_m_blocks;
constexpr int sh_b_size = stages * b_sh_stage;
int4* sh_b = sh_new;
int4* sh_red = sh_new;
int4* sh_g_idx = sh_b + (sh_red_size > sh_b_size ? sh_red_size : sh_b_size);
int4* sh_zp = sh_g_idx + (stages * g_idx_stage);
constexpr int sh_s_size = has_act_order ? (act_s_max_num_groups * s_sh_stride)
: (stages * s_sh_stage);
int4* sh_s = sh_zp + (stages * zp_sh_stage);
int4* sh_red = sh_b;
// shared memory reused by reduction should be smaller than
// shared memory used by weight.
static_assert(thread_m_blocks * 16 * thread_n_blocks * 16 / 8 <=
stages * b_sh_stage);
int4* sh_a = sh_s + sh_s_size;
constexpr int shm_size_used =
moe_block_size + stages * (g_idx_stage + zp_sh_stage) + sh_s_size +
(sh_red_size > sh_b_size ? sh_red_size : sh_b_size);

// all remaining shared memory is used to cache A (input)
// sh_a_max_row is at least ` stages * 16 * thread_m_blocks `
int sh_a_max_row =
((max_shared_mem - 1024) / 16 - shm_size_used) / (thread_k_blocks * 2);

// Register storage for double buffer of shared memory reads.
FragA frag_a[2][thread_m_blocks];
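With this layout the B tile and the reduction buffer alias the same region (sized by whichever is larger), the g_idx/zero-point/scale buffers follow, and whatever remains of max_shared_mem, minus 1 KiB of slack, is handed to the A cache; sh_a_max_row is how many A rows of thread_k_blocks * 2 int4 each still fit. A small arithmetic sketch of that sizing follows; every input value is hypothetical rather than taken from a real launch, and b_sh_stage is passed in instead of derived.

```python
# Arithmetic sketch of the A-cache sizing above. All sizes are in int4
# (16-byte) units and all argument values below are hypothetical.
def a_cache_budget(max_shared_mem, thread_m_blocks, thread_n_blocks,
                   thread_k_blocks, stages, moe_block_size,
                   b_sh_stage, g_idx_stage, zp_sh_stage, sh_s_size):
    sh_red_size = (2 * thread_n_blocks + 1) * 16 * thread_m_blocks
    sh_b_size = stages * b_sh_stage
    shm_size_used = (moe_block_size
                     + stages * (g_idx_stage + zp_sh_stage)
                     + sh_s_size
                     + max(sh_red_size, sh_b_size))
    # 1 KiB of slack, 16 bytes per int4, thread_k_blocks * 2 int4 per A row
    sh_a_max_row = (((max_shared_mem - 1024) // 16 - shm_size_used)
                    // (thread_k_blocks * 2))
    # how many groups of `stages` A tiles fit beyond the first moe block
    max_num_stage_groups = max(
        ((sh_a_max_row - moe_block_size) // moe_block_size + 1) // stages, 1)
    return sh_a_max_row, max_num_stage_groups


print(a_cache_budget(max_shared_mem=96 * 1024, thread_m_blocks=4,
                     thread_n_blocks=16, thread_k_blocks=4, stages=4,
                     moe_block_size=64, b_sh_stage=256, g_idx_stage=0,
                     zp_sh_stage=16, sh_s_size=128))
```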
@@ -905,15 +795,14 @@ __global__ void Marlin(
|
||||
|
||||
int sh_first_group_id = -1;
|
||||
int sh_num_groups = -1;
|
||||
constexpr int sh_max_num_groups = 32;
|
||||
|
||||
auto fetch_act_order_scales_to_shared = [&](bool is_async, int first_group_id,
|
||||
int last_group_id) {
|
||||
sh_first_group_id = first_group_id;
|
||||
sh_num_groups = last_group_id - first_group_id + 1;
|
||||
|
||||
if (sh_num_groups < sh_max_num_groups) {
|
||||
sh_num_groups = sh_max_num_groups;
|
||||
if (sh_num_groups < act_s_max_num_groups) {
|
||||
sh_num_groups = act_s_max_num_groups;
|
||||
}
|
||||
|
||||
if (sh_first_group_id + sh_num_groups > num_groups) {
|
||||
@@ -940,27 +829,31 @@ __global__ void Marlin(
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Asynchronously fetch the next A, B and s tile from global to the next
|
||||
// shared memory pipeline location.
|
||||
int a_remaining_load_count_in_slice = stages;
|
||||
auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true) {
|
||||
bool should_load_a = true;
|
||||
int max_num_stage_groups =
|
||||
((sh_a_max_row - moe_block_size) / moe_block_size + 1) / stages;
|
||||
max_num_stage_groups = max(max_num_stage_groups, 1);
|
||||
auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true,
|
||||
int pipe_a = 0) {
|
||||
if (pred) {
|
||||
int4* sh_a_stage = sh_a + a_sh_stage * pipe;
|
||||
if (prob_k > thread_k_blocks * 16 * stages || slice_col == 0 ||
|
||||
a_remaining_load_count_in_slice > 0) {
|
||||
a_remaining_load_count_in_slice--;
|
||||
if (should_load_a) {
|
||||
int4* sh_a_stage = sh_a + moe_block_size * a_sh_stride * pipe_a;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < a_sh_wr_iters; i++) {
|
||||
int a_idx = a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off;
|
||||
int row = a_idx / a_gl_stride;
|
||||
int row = a_gl_rd_delta_i / a_gl_stride * i + a_gl_rd_row;
|
||||
int64_t sorted_row = 0;
|
||||
if (!m_block_size_8 || row < 8)
|
||||
sorted_row = sh_block_sorted_ids[row] / top_k;
|
||||
int64_t true_idx = sorted_row * a_gl_stride + a_idx % a_gl_stride;
|
||||
sorted_row = sh_rd_block_sorted_ids[row];
|
||||
int64_t true_idx =
|
||||
sorted_row * a_gl_stride + a_gl_rd_col + a_gl_rd_delta_o * a_off;
|
||||
cp_async4_pred(&sh_a_stage[a_sh_wr_trans[i]], &A[true_idx],
|
||||
row < block_num_valid_tokens);
|
||||
}
|
||||
}
|
||||
|
||||
int4* sh_b_stage = sh_b + b_sh_stage * pipe;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < b_sh_wr_iters; i++) {
|
||||
@@ -1063,8 +956,8 @@ __global__ void Marlin(
|
||||
|
||||
// Load the next sub-tile from the current location in the shared memory pipe
|
||||
// into the current register buffer.
|
||||
auto fetch_to_registers = [&](int k, int pipe) {
|
||||
int4* sh_a_stage = sh_a + a_sh_stage * pipe;
|
||||
auto fetch_to_registers = [&](int k, int pipe, int pipe_a = 0) {
|
||||
int4* sh_a_stage = sh_a + moe_block_size * a_sh_stride * pipe_a;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < thread_m_blocks; i++)
|
||||
ldsm<m_block_size_8 ? 2 : 4, scalar_t>(
|
||||
@@ -1109,12 +1002,17 @@ __global__ void Marlin(
|
||||
}
|
||||
} else if constexpr (group_blocks != -1) {
|
||||
if constexpr (group_blocks >= thread_k_blocks) {
|
||||
int4* sh_s_stage =
|
||||
sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) *
|
||||
(pipe / (group_blocks / thread_k_blocks)));
|
||||
reinterpret_cast<int4*>(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd];
|
||||
if (k % b_sh_wr_iters == 0) {
|
||||
int4* sh_s_stage =
|
||||
sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) *
|
||||
(pipe / (group_blocks / thread_k_blocks)));
|
||||
reinterpret_cast<int4*>(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd];
|
||||
} else {
|
||||
reinterpret_cast<int4*>(&frag_s[1])[0] =
|
||||
reinterpret_cast<int4*>(&frag_s[0])[0];
|
||||
}
|
||||
} else {
|
||||
int warp_id = threadIdx.x / 32;
|
||||
auto warp_id = threadIdx.x / 32;
|
||||
int n_warps = thread_n_blocks / 4;
|
||||
|
||||
int warp_row = warp_id / n_warps;
|
||||
@@ -1152,7 +1050,7 @@ __global__ void Marlin(
|
||||
|
||||
// Determine "position" inside the thread-block (based on warp and
|
||||
// thread-id)
|
||||
int warp_id = threadIdx.x / 32;
|
||||
auto warp_id = threadIdx.x / 32;
|
||||
int n_warps =
|
||||
thread_n_blocks / 4; // Each warp processes 4 16-size tiles over N
|
||||
|
||||
@@ -1161,7 +1059,7 @@ __global__ void Marlin(
|
||||
|
||||
cur_k += warp_row * 16;
|
||||
|
||||
int th_id = threadIdx.x % 32;
|
||||
auto th_id = threadIdx.x % 32;
|
||||
cur_k += (th_id % 4) * 2; // Due to tensor-core layout for fp16 B matrix
|
||||
|
||||
int s_col_shift =
|
||||
@@ -1222,15 +1120,18 @@ __global__ void Marlin(
|
||||
}
|
||||
|
||||
} else if constexpr (group_blocks >= thread_k_blocks) {
|
||||
int4* sh_zp_stage =
|
||||
sh_zp + zp_sh_stage * ((group_blocks / thread_k_blocks) *
|
||||
(pipe / (group_blocks / thread_k_blocks)));
|
||||
for (int i = 0; i < num_ints_per_thread; i++) {
|
||||
frag_qzp[k % 2][i] =
|
||||
(reinterpret_cast<int*>(sh_zp_stage))[zp_sh_rd + i];
|
||||
if (k % b_sh_wr_iters == 0) {
|
||||
int4* sh_zp_stage =
|
||||
sh_zp + zp_sh_stage * ((group_blocks / thread_k_blocks) *
|
||||
(pipe / (group_blocks / thread_k_blocks)));
|
||||
#pragma unroll
|
||||
for (int i = 0; i < num_ints_per_thread; i++) {
|
||||
frag_qzp[k % 2][i] =
|
||||
(reinterpret_cast<int*>(sh_zp_stage))[zp_sh_rd + i];
|
||||
}
|
||||
}
|
||||
} else {
|
||||
int warp_id = threadIdx.x / 32;
|
||||
auto warp_id = threadIdx.x / 32;
|
||||
int n_warps = thread_n_blocks / 4;
|
||||
|
||||
int warp_row = warp_id / n_warps;
|
||||
@@ -1251,6 +1152,7 @@ __global__ void Marlin(
|
||||
|
||||
sh_zp_stage += cur_group_id * zp_sh_stride;
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < num_ints_per_thread; i++) {
|
||||
frag_qzp[k % 2][i] =
|
||||
(reinterpret_cast<int*>(sh_zp_stage))[zp_sh_rd + i];
|
||||
@@ -1263,12 +1165,16 @@ __global__ void Marlin(
|
||||
|
||||
if constexpr (group_blocks != -1) {
|
||||
if constexpr (group_blocks >= thread_k_blocks) {
|
||||
int4* sh_zp_stage =
|
||||
sh_zp + zp_sh_stage * ((group_blocks / thread_k_blocks) *
|
||||
(pipe / (group_blocks / thread_k_blocks)));
|
||||
reinterpret_cast<int4*>(&frag_zpf[k % 2])[0] = sh_zp_stage[zp_sh_rd];
|
||||
if (k % b_sh_wr_iters == 0) {
|
||||
int4* sh_zp_stage =
|
||||
sh_zp +
|
||||
zp_sh_stage * ((group_blocks / thread_k_blocks) *
|
||||
(pipe / (group_blocks / thread_k_blocks)));
|
||||
reinterpret_cast<int4*>(&frag_zpf[k % 2])[0] =
|
||||
sh_zp_stage[zp_sh_rd];
|
||||
}
|
||||
} else {
|
||||
int warp_id = threadIdx.x / 32;
|
||||
auto warp_id = threadIdx.x / 32;
|
||||
int n_warps = thread_n_blocks / 4;
|
||||
|
||||
int warp_row = warp_id / n_warps;
|
||||
@@ -1292,6 +1198,25 @@ __global__ void Marlin(
|
||||
}
|
||||
};
|
||||
|
||||
auto dequant_data = [&](int q, scalar_t2* frag_b_ptr) {
|
||||
if constexpr (has_zp && is_zp_float || !has_zp) {
|
||||
dequant<scalar_t2, w_type_id>(q, frag_b_ptr);
|
||||
} else {
|
||||
static_assert(has_zp && !is_zp_float);
|
||||
static_assert(w_type_id == vllm::kU4.id() || w_type_id == vllm::kU8.id());
|
||||
// If (has_zp && !is_zp_float),
// we use the no-zero-point variant of `dequant`
// to improve numerical accuracy.
// Since both the weight and the zero point are dequantized with this logic,
// the final dequantized weight is still correct.
|
||||
if constexpr (w_type_id == vllm::kU4.id()) {
|
||||
dequant<scalar_t2, vllm::kU4B8.id()>(q, frag_b_ptr);
|
||||
} else if constexpr (w_type_id == vllm::kU8.id()) {
|
||||
dequant<scalar_t2, vllm::kU8B128.id()>(q, frag_b_ptr);
|
||||
}
|
||||
}
|
||||
};
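The comment above relies on the bias cancelling out: dequantizing both the weight and its zero point with the biased (kU4B8/kU8B128-style) routine shifts each value by the same constant, and that shift disappears as soon as the zero point is subtracted. A toy integer check of the identity, assuming 4-bit values and the usual GPTQ bias of 8:

```python
# (w - 8) - (z - 8) == w - z, so the symmetric bias cancels when the
# zero point is subtracted. Toy values; scale is an arbitrary group scale.
w_q, z_q, scale = 11, 3, 0.5

exact = scale * (w_q - z_q)            # reference dequantization

w_biased = w_q - 8                      # kU4B8-style dequant of the weight
z_biased = z_q - 8                      # same biased dequant of the zero point
fused = scale * (w_biased - z_biased)   # bias cancels in the subtraction

assert fused == exact
print(exact, fused)
```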
|
||||
|
||||
// Execute the actual tensor core matmul of a sub-tile.
|
||||
bool is_first_matmul_in_slice = true;
|
||||
auto matmul = [&](int k) {
|
||||
@@ -1315,15 +1240,17 @@ __global__ void Marlin(
|
||||
zp_quant_1 = frag_qzp[k2][1];
|
||||
}
|
||||
|
||||
dequant<scalar_t, w_type.size_bits()>(zp_quant_0, frag_zp_0);
|
||||
dequant<scalar_t, w_type.size_bits()>(zp_quant_1, frag_zp_1);
|
||||
|
||||
frag_zp[0] = frag_zp_0[0];
|
||||
frag_zp[1] = frag_zp_0[1];
|
||||
frag_zp[2] = frag_zp_1[0];
|
||||
frag_zp[3] = frag_zp_1[1];
|
||||
dequant_data(zp_quant_0, reinterpret_cast<scalar_t2*>(&frag_zp));
|
||||
dequant_data(zp_quant_1, reinterpret_cast<scalar_t2*>(&frag_zp) + 2);
|
||||
}
|
||||
}
|
||||
if constexpr (has_zp && is_zp_float) {
|
||||
if (is_new_zp) {
|
||||
reinterpret_cast<int4*>(&frag_zp)[0] =
|
||||
reinterpret_cast<int4*>(&frag_zpf[k2])[0];
|
||||
}
|
||||
}
|
||||
|
||||
// We have the m dimension as the inner loop in order to encourage overlapping
|
||||
// dequantization and matmul operations.
|
||||
#pragma unroll
|
||||
@@ -1342,8 +1269,8 @@ __global__ void Marlin(
|
||||
b_quant_1 = frag_b_quant_ptr[j * 2 + 1];
|
||||
}
|
||||
|
||||
dequant<scalar_t, w_type.size_bits()>(b_quant_0, frag_b0);
|
||||
dequant<scalar_t, w_type.size_bits()>(b_quant_1, frag_b1);
|
||||
dequant_data(b_quant_0, reinterpret_cast<scalar_t2*>(&frag_b0));
|
||||
dequant_data(b_quant_1, reinterpret_cast<scalar_t2*>(&frag_b1));
|
||||
|
||||
// Apply scale to frag_b0
|
||||
if constexpr (has_act_order) {
|
||||
@@ -1351,8 +1278,7 @@ __global__ void Marlin(
|
||||
scale4<scalar_t>(frag_b0, act_frag_s[k2][0][j], act_frag_s[k2][1][j],
|
||||
act_frag_s[k2][2][j], act_frag_s[k2][3][j], 0);
|
||||
scale4<scalar_t>(frag_b1, act_frag_s[k2][0][j], act_frag_s[k2][1][j],
|
||||
act_frag_s[k][2][j], act_frag_s[k2][3][j], 1);
|
||||
|
||||
act_frag_s[k2][2][j], act_frag_s[k2][3][j], 1);
|
||||
} else if constexpr (has_zp && !is_zp_float && group_blocks == -1) {
|
||||
int idx = (threadIdx.x / 4) % 2;
|
||||
scalar_t2 s2 = Dtype::nums2num2(
|
||||
@ -1361,18 +1287,12 @@ __global__ void Marlin(
|
||||
if (is_new_zp) frag_zp[j] = __hmul2(frag_zp[j], s2);
|
||||
scale_and_sub<scalar_t>(frag_b0, s2.x, frag_zp[j].x);
|
||||
scale_and_sub<scalar_t>(frag_b1, s2.y, frag_zp[j].y);
|
||||
} else if constexpr (has_zp && !is_zp_float && group_blocks != -1) {
|
||||
} else if constexpr (has_zp && group_blocks != -1) {
|
||||
if (is_new_zp)
|
||||
frag_zp[j] = __hmul2(frag_zp[j],
|
||||
*reinterpret_cast<scalar_t2*>(&frag_s[k2][j]));
|
||||
scale_and_sub<scalar_t>(frag_b0, frag_s[k % 2][j][0].x, frag_zp[j].x);
|
||||
scale_and_sub<scalar_t>(frag_b1, frag_s[k % 2][j][0].y, frag_zp[j].y);
|
||||
} else if constexpr (has_zp && is_zp_float && group_blocks != -1) {
|
||||
if (is_new_zp)
|
||||
frag_zpf[k2][j] = __hmul2(
|
||||
frag_zpf[k2][j], *reinterpret_cast<scalar_t2*>(&frag_s[k2][j]));
|
||||
scale_and_sub<scalar_t>(frag_b0, frag_s[k2][j].x, frag_zpf[k2][j].x);
|
||||
scale_and_sub<scalar_t>(frag_b1, frag_s[k2][j].y, frag_zpf[k2][j].y);
|
||||
scale_and_sub<scalar_t>(frag_b0, frag_s[k2][j][0].x, frag_zp[j].x);
|
||||
scale_and_sub<scalar_t>(frag_b1, frag_s[k2][j][0].y, frag_zp[j].y);
|
||||
} else if constexpr (group_blocks != -1) {
|
||||
scale<scalar_t>(frag_b0, frag_s[k2][j], 0);
|
||||
scale<scalar_t>(frag_b1, frag_s[k2][j], 1);
|
||||
@@ -1397,7 +1317,7 @@ __global__ void Marlin(
|
||||
auto thread_block_reduce = [&]() {
|
||||
constexpr int red_off = threads / b_sh_stride_threads / 2;
|
||||
if (red_off >= 1) {
|
||||
int red_idx = threadIdx.x / b_sh_stride_threads;
|
||||
auto red_idx = threadIdx.x / b_sh_stride_threads;
|
||||
constexpr int red_sh_stride = b_sh_stride_threads * 4 * 2;
|
||||
constexpr int red_sh_delta = b_sh_stride_threads;
|
||||
int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride_threads) +
|
||||
@@ -1731,7 +1651,7 @@ __global__ void Marlin(
|
||||
fetch_col_scale_to_shared();
|
||||
}
|
||||
}
|
||||
fetch_to_shared(i, i, i < slice_iters);
|
||||
fetch_to_shared(i, i, i < slice_iters, i);
|
||||
}
|
||||
|
||||
zero_accums();
|
||||
@@ -1740,8 +1660,10 @@ __global__ void Marlin(
|
||||
fetch_to_registers(0, 0);
|
||||
fetch_scales_to_registers(0, 0);
|
||||
fetch_zp_to_registers(0, 0);
|
||||
a_gl_rd += a_gl_rd_delta_o * (stages - 1);
|
||||
slice_k_start_shared_fetch += tb_k * (stages - 1);
|
||||
a_gl_rd_col += a_gl_rd_delta_o * (stages - 1);
|
||||
if constexpr (has_act_order) {
|
||||
slice_k_start_shared_fetch += tb_k * (stages - 1);
|
||||
}
|
||||
};
|
||||
if (slice_iters) {
|
||||
start_pipes();
|
||||
@@ -1754,45 +1676,58 @@ __global__ void Marlin(
|
||||
// have even length meaning that the next iteration will always start at
|
||||
// index 0.
|
||||
|
||||
for (int stage_group_id = 0; stage_group_id < max_num_stage_groups;
|
||||
stage_group_id++) {
|
||||
#pragma unroll
|
||||
for (int pipe = 0; pipe < stages;) {
|
||||
for (int pipe = 0; pipe < stages;) {
|
||||
#pragma unroll
|
||||
for (int k = 0; k < b_sh_wr_iters; k++) {
|
||||
fetch_to_registers(k + 1, pipe % stages);
|
||||
fetch_scales_to_registers(k + 1, pipe);
|
||||
fetch_zp_to_registers(k + 1, pipe);
|
||||
if (k == b_sh_wr_iters - 2) {
|
||||
fetch_to_shared((pipe + stages - 1) % stages, pipe,
|
||||
slice_iters >= stages);
|
||||
pipe++;
|
||||
wait_for_stage();
|
||||
init_same_group(pipe % stages);
|
||||
for (int k = 0; k < b_sh_wr_iters; k++) {
|
||||
int idx =
|
||||
(pipe >= stages && stage_group_id == max_num_stage_groups - 1)
|
||||
? (pipe - stages)
|
||||
: (pipe + stage_group_id * stages);
|
||||
fetch_to_registers(k + 1, pipe % stages, idx);
|
||||
fetch_scales_to_registers(k + 1, pipe);
|
||||
fetch_zp_to_registers(k + 1, pipe);
|
||||
if (k == b_sh_wr_iters - 2) {
|
||||
int idx = (pipe >= 1 && stage_group_id == max_num_stage_groups - 1)
|
||||
? (pipe - 1)
|
||||
: (pipe + (stage_group_id + 1) * stages - 1);
|
||||
fetch_to_shared((pipe + stages - 1) % stages, pipe,
|
||||
slice_iters >= stages, idx);
|
||||
pipe++;
|
||||
wait_for_stage();
|
||||
init_same_group(pipe % stages);
|
||||
}
|
||||
matmul(k);
|
||||
}
|
||||
slice_iters--;
|
||||
if (slice_iters == 0) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
a_gl_rd_col += a_gl_rd_delta_o * stages;
|
||||
|
||||
if constexpr (has_act_order) {
|
||||
slice_k_start += tb_k * stages;
|
||||
slice_k_start_shared_fetch += tb_k * stages;
|
||||
int first_group_id = g_idx[slice_k_start];
|
||||
int last_g_idx = slice_k_start + stages * tb_k * 2;
|
||||
if (last_g_idx >= prob_k) {
|
||||
last_g_idx = prob_k - 1;
|
||||
}
|
||||
int last_group_id = g_idx[last_g_idx];
|
||||
if (last_group_id >= sh_first_group_id + sh_num_groups) {
|
||||
fetch_act_order_scales_to_shared(false, first_group_id,
|
||||
last_group_id);
|
||||
__syncthreads();
|
||||
}
|
||||
matmul(k);
|
||||
}
|
||||
slice_iters--;
|
||||
if (slice_iters == 0) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
a_remaining_load_count_in_slice = 0;
|
||||
|
||||
a_gl_rd += a_gl_rd_delta_o * stages;
|
||||
slice_k_start += tb_k * stages;
|
||||
slice_k_start_shared_fetch += tb_k * stages;
|
||||
|
||||
if constexpr (has_act_order) {
|
||||
int first_group_id = g_idx[slice_k_start];
|
||||
int last_g_idx = slice_k_start + stages * tb_k * 2;
|
||||
if (last_g_idx >= prob_k) {
|
||||
last_g_idx = prob_k - 1;
|
||||
}
|
||||
int last_group_id = g_idx[last_g_idx];
|
||||
if (last_group_id >= sh_first_group_id + sh_num_groups) {
|
||||
fetch_act_order_scales_to_shared(false, first_group_id, last_group_id);
|
||||
__syncthreads();
|
||||
}
|
||||
}
|
||||
|
||||
// Process results and, if necessary, proceed to the next column slice.
|
||||
// While this pattern may not be the most readable, other ways of writing
|
||||
@@ -1877,15 +1812,30 @@ __global__ void Marlin(
|
||||
if (last || use_atomic_add)
|
||||
// only the last block in a slice actually writes the result
|
||||
write_result();
|
||||
if (slice_row) a_remaining_load_count_in_slice = stages;
|
||||
int old_slice_row = slice_row;
|
||||
slice_row = 0;
|
||||
slice_col_par++;
|
||||
slice_col++;
|
||||
is_first_matmul_in_slice = true;
|
||||
init_slice();
|
||||
|
||||
// Should we load A matrix in next slice?
|
||||
// `slice_col == 0`: when move to a new moe block
|
||||
// `old_slice_row > 0`:
|
||||
// when the last slice is not starting from k_index == 0
|
||||
// (only happen when it is the first slice of a threadblock)
|
||||
// `prob_k > thread_k_blocks * 16 * stages * max_num_stage_groups`:
|
||||
// when the required shared memory size is larger than
|
||||
// the remaining shared memory
|
||||
if (slice_col == 0 || old_slice_row ||
|
||||
prob_k > thread_k_blocks * 16 * stages * max_num_stage_groups) {
|
||||
should_load_a = true;
|
||||
} else {
|
||||
should_load_a = false;
|
||||
}
|
||||
|
||||
if (slice_iters) {
|
||||
a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) +
|
||||
(threadIdx.x % a_gl_rd_delta_o);
|
||||
a_gl_rd_col = (threadIdx.x % a_gl_rd_delta_o);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < b_sh_wr_iters; i++)
|
||||
B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles;
|
||||
@@ -1900,12 +1850,10 @@ __global__ void Marlin(
|
||||
slice_k_finish = slice_k_start + tb_k * slice_iters;
|
||||
slice_k_start_shared_fetch = slice_k_start;
|
||||
slice_n_offset = act_s_col_tb_stride * slice_col;
|
||||
|
||||
} else {
|
||||
s_gl_rd = s_sh_stride * slice_col + threadIdx.x;
|
||||
zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x;
|
||||
}
|
||||
|
||||
start_pipes();
|
||||
}
|
||||
}
|
||||
|
@@ -116,7 +116,7 @@ __global__ void permute_cols_kernel(
|
||||
int base_k = 0;
|
||||
|
||||
for (int i = 0; i < iters; i++) {
|
||||
int cur_k = base_k + threadIdx.x;
|
||||
auto cur_k = base_k + threadIdx.x;
|
||||
int src_pos = perm_int_ptr[cur_k];
|
||||
|
||||
out_half[cur_k] = a_row_half[src_pos];
|
||||
@@ -126,7 +126,7 @@ __global__ void permute_cols_kernel(
|
||||
|
||||
if (rest) {
|
||||
if (threadIdx.x < rest) {
|
||||
int cur_k = base_k + threadIdx.x;
|
||||
auto cur_k = base_k + threadIdx.x;
|
||||
int src_pos = perm_int_ptr[cur_k];
|
||||
|
||||
out_half[cur_k] = a_row_half[src_pos];
|
||||
@@ -195,7 +195,6 @@ int get_scales_cache_size(thread_config_t const& th_config, int prob_m,
|
||||
tb_groups * pipe_stages * 2; // Chunk size is 2x pipeline over dim K
|
||||
load_groups = max(load_groups, 32); // We load at least 32 scale groups
|
||||
return load_groups * tb_n * 2;
|
||||
|
||||
} else {
|
||||
int tb_scales = tb_groups * tb_n * 2;
|
||||
|
||||
@@ -203,22 +202,24 @@ int get_scales_cache_size(thread_config_t const& th_config, int prob_m,
|
||||
}
|
||||
}
|
||||
|
||||
int get_kernel_cache_size(thread_config_t const& th_config, int thread_m_blocks,
|
||||
int prob_m, int prob_n, int prob_k, int num_bits,
|
||||
int group_size, bool has_act_order, bool is_k_full,
|
||||
int has_zp, int is_zp_float) {
|
||||
int get_kernel_cache_size(thread_config_t const& th_config, bool m_block_size_8,
|
||||
int thread_m_blocks, int prob_m, int prob_n,
|
||||
int prob_k, int num_bits, int group_size,
|
||||
bool has_act_order, bool is_k_full, int has_zp,
|
||||
int is_zp_float) {
|
||||
int pack_factor = 32 / num_bits;
|
||||
|
||||
// Get B size
|
||||
int tb_k = th_config.thread_k;
|
||||
int tb_n = th_config.thread_n;
|
||||
int tb_m = thread_m_blocks * 16;
|
||||
int tb_m = thread_m_blocks * (m_block_size_8 ? 8 : 16);
|
||||
|
||||
// shm size for block_sorted_ids/block_topk_weights
|
||||
// shm size for block_sorted_ids/rd_block_sorted_ids/block_topk_weights
|
||||
// each of them requires tb_m * 4 bytes (tb_m * int32 or tb_m * float32)
|
||||
int sh_block_meta_size = tb_m * 4 * 2;
|
||||
int sh_block_meta_size = tb_m * 4;
|
||||
int sh_a_size = pipe_stages * (tb_m * tb_k) * 2;
|
||||
int sh_b_size = pipe_stages * (tb_k * tb_n / pack_factor) * 4;
|
||||
int sh_red_size = tb_m * (tb_n + 8) * 2;
|
||||
int sh_s_size =
|
||||
get_scales_cache_size(th_config, prob_m, prob_n, prob_k, num_bits,
|
||||
group_size, has_act_order, is_k_full);
|
||||
@@ -233,16 +234,17 @@ int get_kernel_cache_size(thread_config_t const& th_config, int thread_m_blocks,
|
||||
sh_zp_size = sh_s_size / 2;
|
||||
}
|
||||
|
||||
int total_size = sh_a_size + sh_b_size + sh_s_size + sh_zp_size +
|
||||
sh_g_idx_size + sh_block_meta_size;
|
||||
int total_size = max(sh_b_size, sh_red_size) + sh_a_size + sh_s_size +
|
||||
sh_zp_size + sh_g_idx_size + sh_block_meta_size;
|
||||
|
||||
return total_size;
|
||||
}
|
||||
|
||||
bool is_valid_config(thread_config_t const& th_config, int thread_m_blocks,
|
||||
int prob_m, int prob_n, int prob_k, int num_bits,
|
||||
int group_size, bool has_act_order, bool is_k_full,
|
||||
int has_zp, int is_zp_float, int max_shared_mem) {
|
||||
bool is_valid_config(thread_config_t const& th_config, bool m_block_size_8,
|
||||
int thread_m_blocks, int prob_m, int prob_n, int prob_k,
|
||||
int num_bits, int group_size, bool has_act_order,
|
||||
bool is_k_full, int has_zp, int is_zp_float,
|
||||
int max_shared_mem) {
|
||||
// Sanity
|
||||
if (th_config.thread_k == -1 || th_config.thread_n == -1 ||
|
||||
th_config.num_threads == -1) {
|
||||
@@ -266,143 +268,113 @@ bool is_valid_config(thread_config_t const& th_config, int thread_m_blocks,
|
||||
|
||||
// Check that pipeline fits into cache
|
||||
int cache_size = get_kernel_cache_size(
|
||||
th_config, thread_m_blocks, prob_m, prob_n, prob_k, num_bits, group_size,
|
||||
has_act_order, is_k_full, has_zp, is_zp_float);
|
||||
th_config, m_block_size_8, thread_m_blocks, prob_m, prob_n, prob_k,
|
||||
num_bits, group_size, has_act_order, is_k_full, has_zp, is_zp_float);
|
||||
return cache_size <= max_shared_mem;
|
||||
}
|
||||
|
||||
#define __GET_IF(W_TYPE, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \
|
||||
M_BLOCK_SIZE_8, HAS_ACT_ORDER, HAS_ZP, GROUP_BLOCKS, \
|
||||
NUM_THREADS, IS_ZP_FLOAT) \
|
||||
else if (q_type == W_TYPE && thread_m_blocks == THREAD_M_BLOCKS && \
|
||||
thread_n_blocks == THREAD_N_BLOCKS && \
|
||||
thread_k_blocks == THREAD_K_BLOCKS && \
|
||||
m_block_size_8 == M_BLOCK_SIZE_8 && \
|
||||
has_act_order == HAS_ACT_ORDER && has_zp == HAS_ZP && \
|
||||
group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS && \
|
||||
is_zp_float == IS_ZP_FLOAT) { \
|
||||
kernel = Marlin<scalar_t, W_TYPE.id(), NUM_THREADS, THREAD_M_BLOCKS, \
|
||||
THREAD_N_BLOCKS, THREAD_K_BLOCKS, M_BLOCK_SIZE_8, \
|
||||
pipe_stages, HAS_ACT_ORDER, HAS_ZP, GROUP_BLOCKS, \
|
||||
IS_ZP_FLOAT>; \
|
||||
#define _GET_IF(W_TYPE, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \
|
||||
M_BLOCK_SIZE_8, GROUP_BLOCKS, NUM_THREADS, IS_ZP_FLOAT) \
|
||||
else if (q_type == W_TYPE && thread_m_blocks == THREAD_M_BLOCKS && \
|
||||
thread_n_blocks == THREAD_N_BLOCKS && \
|
||||
thread_k_blocks == THREAD_K_BLOCKS && \
|
||||
m_block_size_8 == M_BLOCK_SIZE_8 && \
|
||||
group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS && \
|
||||
is_zp_float == IS_ZP_FLOAT) { \
|
||||
kernel = Marlin<scalar_t, W_TYPE.id(), NUM_THREADS, THREAD_M_BLOCKS, \
|
||||
THREAD_N_BLOCKS, THREAD_K_BLOCKS, M_BLOCK_SIZE_8, \
|
||||
pipe_stages, GROUP_BLOCKS, IS_ZP_FLOAT>; \
|
||||
}
|
||||
|
||||
#define GPTQ_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
|
||||
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, true, false, 0, NUM_THREADS, \
|
||||
false) \
|
||||
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, false, 0, \
|
||||
NUM_THREADS, false) \
|
||||
\
|
||||
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, false, -1, \
|
||||
NUM_THREADS, false) \
|
||||
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, false, 2, \
|
||||
NUM_THREADS, false) \
|
||||
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, false, 4, \
|
||||
NUM_THREADS, false) \
|
||||
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, false, 8, \
|
||||
NUM_THREADS, false) \
|
||||
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, false, -1, \
|
||||
NUM_THREADS, false) \
|
||||
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, false, 2, \
|
||||
NUM_THREADS, false) \
|
||||
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, false, 4, \
|
||||
NUM_THREADS, false) \
|
||||
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, false, 8, \
|
||||
NUM_THREADS, false)
|
||||
// COMMON: cases for (group_blocks in [-1, 2, 4, 8] and is_zp_float == false)
|
||||
// this is the most common cases
|
||||
// BIGGROUP: cases for big group size (group_blocks in [-1, 8])
|
||||
// FZP: cases for float-zero-point (is_zp_float = true)
|
||||
// ACT: cases for act order case (group_blocks == 0)
|
||||
#define COMMON_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
|
||||
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, -1, NUM_THREADS, false) \
|
||||
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 2, NUM_THREADS, false) \
|
||||
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 4, NUM_THREADS, false) \
|
||||
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 8, NUM_THREADS, false) \
|
||||
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
|
||||
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \
|
||||
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \
|
||||
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false)
|
||||
|
||||
#define GPTQ_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
|
||||
__GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, false, 0, \
|
||||
NUM_THREADS, false) \
|
||||
__GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, false, 0, \
|
||||
NUM_THREADS, false) \
|
||||
__GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, false, 0, \
|
||||
NUM_THREADS, false) \
|
||||
\
|
||||
__GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, false, -1, \
|
||||
NUM_THREADS, false) \
|
||||
__GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, false, 2, \
|
||||
NUM_THREADS, false) \
|
||||
__GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, false, 4, \
|
||||
NUM_THREADS, false) \
|
||||
__GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, false, 8, \
|
||||
NUM_THREADS, false) \
|
||||
\
|
||||
__GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, false, -1, \
|
||||
NUM_THREADS, false) \
|
||||
__GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, false, 2, \
|
||||
NUM_THREADS, false) \
|
||||
__GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, false, 4, \
|
||||
NUM_THREADS, false) \
|
||||
__GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, false, 8, \
|
||||
NUM_THREADS, false) \
|
||||
\
|
||||
__GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, false, -1, \
|
||||
NUM_THREADS, false) \
|
||||
__GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, false, 2, \
|
||||
NUM_THREADS, false) \
|
||||
__GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, false, 4, \
|
||||
NUM_THREADS, false) \
|
||||
__GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, false, 8, \
|
||||
NUM_THREADS, false)
|
||||
#define COMMON_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
|
||||
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
|
||||
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \
|
||||
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \
|
||||
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \
|
||||
\
|
||||
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
|
||||
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \
|
||||
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \
|
||||
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \
|
||||
\
|
||||
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
|
||||
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \
|
||||
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \
|
||||
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false)
|
||||
|
||||
#define AWQ_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
|
||||
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, true, -1, \
|
||||
NUM_THREADS, false) \
|
||||
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, true, 2, NUM_THREADS, \
|
||||
false) \
|
||||
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, true, 4, NUM_THREADS, \
|
||||
false) \
|
||||
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, true, 8, NUM_THREADS, \
|
||||
false) \
|
||||
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, true, -1, \
|
||||
NUM_THREADS, false) \
|
||||
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, true, 2, \
|
||||
NUM_THREADS, false) \
|
||||
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, true, 4, \
|
||||
NUM_THREADS, false) \
|
||||
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, true, 8, \
|
||||
NUM_THREADS, false)
|
||||
#define COMMON_GET_IF(W_TYPE) \
|
||||
COMMON_GET_IF_M1(W_TYPE, 8, 8, 256) \
|
||||
COMMON_GET_IF_M1(W_TYPE, 8, 4, 128) \
|
||||
COMMON_GET_IF_M234(W_TYPE, 16, 4, 256) \
|
||||
COMMON_GET_IF_M234(W_TYPE, 8, 4, 128)
|
||||
|
||||
#define AWQ_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
|
||||
__GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, true, -1, \
|
||||
NUM_THREADS, false) \
|
||||
__GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, true, 2, \
|
||||
NUM_THREADS, false) \
|
||||
__GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, true, 4, \
|
||||
NUM_THREADS, false) \
|
||||
__GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, true, 8, \
|
||||
NUM_THREADS, false) \
|
||||
\
|
||||
__GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, true, -1, \
|
||||
NUM_THREADS, false) \
|
||||
__GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, true, 2, \
|
||||
NUM_THREADS, false) \
|
||||
__GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, true, 4, \
|
||||
NUM_THREADS, false) \
|
||||
__GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, true, 8, \
|
||||
NUM_THREADS, false) \
|
||||
\
|
||||
__GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, true, -1, \
|
||||
NUM_THREADS, false) \
|
||||
__GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, true, 2, \
|
||||
NUM_THREADS, false) \
|
||||
__GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, true, 4, \
|
||||
NUM_THREADS, false) \
|
||||
__GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, true, 8, \
|
||||
NUM_THREADS, false)
|
||||
#define BIGGROUP_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
|
||||
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, -1, NUM_THREADS, false) \
|
||||
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 8, NUM_THREADS, false) \
|
||||
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
|
||||
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false)
|
||||
|
||||
#define BIGGROUP_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
|
||||
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
|
||||
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \
|
||||
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
|
||||
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \
|
||||
\
|
||||
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
|
||||
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false)
|
||||
|
||||
#define BIGGROUP_GET_IF(W_TYPE) \
|
||||
BIGGROUP_GET_IF_M1(W_TYPE, 8, 8, 256) \
|
||||
BIGGROUP_GET_IF_M1(W_TYPE, 8, 4, 128) \
|
||||
BIGGROUP_GET_IF_M234(W_TYPE, 16, 4, 256) \
|
||||
BIGGROUP_GET_IF_M234(W_TYPE, 8, 4, 128)
|
||||
|
||||
// We currently have 4-bit models only with group_blocks == 4
|
||||
#define HQQ_GET_IF(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
|
||||
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, true, 4, NUM_THREADS, \
|
||||
true) \
|
||||
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, true, 4, \
|
||||
NUM_THREADS, true) \
|
||||
__GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, true, 4, \
|
||||
NUM_THREADS, true) \
|
||||
__GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, true, 4, \
|
||||
NUM_THREADS, true) \
|
||||
__GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, true, 4, \
|
||||
NUM_THREADS, true)
|
||||
#define FZP_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
|
||||
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 4, NUM_THREADS, true) \
|
||||
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true)
|
||||
|
||||
#define FZP_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
|
||||
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true) \
|
||||
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true) \
|
||||
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true)
|
||||
|
||||
#define FZP_GET_IF(W_TYPE) \
|
||||
FZP_GET_IF_M1(W_TYPE, 8, 8, 256) \
|
||||
FZP_GET_IF_M1(W_TYPE, 8, 4, 128) \
|
||||
FZP_GET_IF_M234(W_TYPE, 16, 4, 256) \
|
||||
FZP_GET_IF_M234(W_TYPE, 8, 4, 128)
|
||||
|
||||
// We currently have 4-bit models only with group_blocks == 4
|
||||
#define ACT_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
|
||||
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS, false) \
|
||||
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false)
|
||||
|
||||
#define ACT_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
|
||||
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false) \
|
||||
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false) \
|
||||
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false)
|
||||
|
||||
#define ACT_GET_IF(W_TYPE) \
|
||||
ACT_GET_IF_M1(W_TYPE, 8, 8, 256) \
|
||||
ACT_GET_IF_M1(W_TYPE, 8, 4, 128) \
|
||||
ACT_GET_IF_M234(W_TYPE, 16, 4, 256) \
|
||||
ACT_GET_IF_M234(W_TYPE, 8, 4, 128)
|
||||
|
||||
template <typename scalar_t>
|
||||
MarlinFuncPtr get_marlin_kernel(const vllm::ScalarType q_type,
|
||||
@@ -415,23 +387,15 @@ MarlinFuncPtr get_marlin_kernel(const vllm::ScalarType q_type,
|
||||
auto kernel = MarlinDefault;
|
||||
if (false) {
|
||||
}
|
||||
GPTQ_GET_IF_M1(vllm::kU4B8, 8, 8, 256)
|
||||
GPTQ_GET_IF_M1(vllm::kU4B8, 8, 4, 128)
|
||||
|
||||
GPTQ_GET_IF_M234(vllm::kU4B8, 16, 4, 256)
|
||||
GPTQ_GET_IF_M234(vllm::kU4B8, 8, 4, 128)
|
||||
COMMON_GET_IF(vllm::kU4)
|
||||
COMMON_GET_IF(vllm::kU4B8)
|
||||
COMMON_GET_IF(vllm::kU8B128)
|
||||
|
||||
GPTQ_GET_IF_M1(vllm::kU8B128, 8, 8, 256)
|
||||
GPTQ_GET_IF_M1(vllm::kU8B128, 8, 4, 128)
|
||||
BIGGROUP_GET_IF(vllm::kFE4M3fn)
|
||||
|
||||
GPTQ_GET_IF_M234(vllm::kU8B128, 16, 4, 256)
|
||||
GPTQ_GET_IF_M234(vllm::kU8B128, 8, 4, 128)
|
||||
|
||||
AWQ_GET_IF_M1(vllm::kU4, 8, 8, 256)
|
||||
AWQ_GET_IF_M1(vllm::kU4, 8, 4, 128)
|
||||
|
||||
AWQ_GET_IF_M234(vllm::kU4, 16, 4, 256)
|
||||
AWQ_GET_IF_M234(vllm::kU4, 8, 4, 128)
|
||||
ACT_GET_IF(vllm::kU4B8)
|
||||
ACT_GET_IF(vllm::kU8B128)
|
||||
|
||||
return kernel;
|
||||
}
|
||||
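
Note on the dispatch above: each *_GET_IF macro expands into a chain of `if` tests that compare the runtime configuration against one explicitly instantiated Marlin template and, on a match, takes that instantiation's function pointer; when nothing matches, `kernel` stays at MarlinDefault. A rough Python sketch of the same idea (the table contents below are illustrative, not the real kernel registry):

# Illustrative only: dispatch a runtime config onto a fixed set of
# precompiled kernels, with a default fallback, like the GET_IF chains.
KERNELS = {
    # (w_type, thread_m_blocks, n_blocks, k_blocks, group_blocks, threads)
    ("kU4B8", 1, 8, 8, 4, 256): "Marlin<half, kU4B8, 256, 1, 8, 8, ...>",
    ("kU4B8", 4, 16, 4, 4, 256): "Marlin<half, kU4B8, 256, 4, 16, 4, ...>",
}

def get_marlin_kernel(w_type, thread_m_blocks, n_blocks, k_blocks,
                      group_blocks, threads):
    # First exact match wins; otherwise fall back to the slow default kernel.
    return KERNELS.get(
        (w_type, thread_m_blocks, n_blocks, k_blocks, group_blocks, threads),
        "MarlinDefault")
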
@ -457,19 +421,19 @@ exec_config_t determine_exec_config(const vllm::ScalarType& q_type, int prob_m,
  for (int i = 0; i < thread_configs_size; i++) {
    thread_config_t th_config = thread_configs[i];

    if (!is_valid_config(th_config, thread_m_blocks, prob_m, prob_n, prob_k,
                         num_bits, group_size, has_act_order, is_k_full, has_zp,
                         is_zp_float, max_shared_mem)) {
    if (!is_valid_config(th_config, m_block_size_8, thread_m_blocks, prob_m,
                         prob_n, prob_k, num_bits, group_size, has_act_order,
                         is_k_full, has_zp, is_zp_float, max_shared_mem)) {
      continue;
    }

    int cache_size = get_kernel_cache_size(
        th_config, thread_m_blocks, prob_m, prob_n, prob_k, num_bits,
        group_size, has_act_order, is_k_full, has_zp, is_zp_float);
        th_config, m_block_size_8, thread_m_blocks, prob_m, prob_n, prob_k,
        num_bits, group_size, has_act_order, is_k_full, has_zp, is_zp_float);

    int group_blocks = 0;
    if (!has_act_order) {
      group_blocks = group_size == -1 ? -1 : group_size / 16;
      group_blocks = group_size == -1 ? -1 : (group_size / 16);
    }

    auto kernel = get_marlin_kernel<scalar_t>(
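
The selection loop above filters candidate thread shapes with is_valid_config (which now also receives m_block_size_8) and sizes each survivor with get_kernel_cache_size; the exact scoring is not visible in this hunk, so the sketch below simply assumes "smallest cache footprint wins". Helper names are stand-ins for the C++ functions of the same name, not real APIs.

def pick_exec_config(thread_configs, problem, group_size, has_act_order,
                     max_shared_mem, is_valid_config, get_kernel_cache_size):
    # `problem` bundles (m, n, k) plus dtype/flags; the two helpers are
    # injected so this sketch stays self-contained.
    best = None
    for th in thread_configs:
        if not is_valid_config(th, problem, group_size, has_act_order,
                               max_shared_mem):
            continue
        # group_blocks mirrors the C++ above: 0 for act-order, -1 for
        # channelwise, otherwise group_size / 16 row tiles per group.
        group_blocks = 0 if has_act_order else (
            -1 if group_size == -1 else group_size // 16)
        cache_size = get_kernel_cache_size(th, problem, group_size)
        if best is None or cache_size < best[0]:
            best = (cache_size, th, group_blocks)
    return best
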
@ -515,14 +479,14 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
  bool m_block_size_8 = moe_block_size == 8;

  if (has_zp) {
    TORCH_CHECK(
        q_type == vllm::kU4 || q_type == vllm::kU8,
        "q_type must be u4 or u8 when has_zp = True. Got = ", q_type.str());
    TORCH_CHECK(q_type == vllm::kU4,
                "q_type must be u4 when has_zp = True. Got = ", q_type.str());
  } else {
    TORCH_CHECK(
        q_type == vllm::kU4B8 || q_type == vllm::kU8B128,
        "q_type must be uint4b8 or uint8b128 when has_zp = False. Got = ",
        q_type.str());
    TORCH_CHECK(q_type == vllm::kU4B8 || q_type == vllm::kU8B128 ||
                    q_type == vllm::kFE4M3fn,
                "q_type must be uint4b8, uint8b128 or fp8e4m3 when has_zp = "
                "False. Got = ",
                q_type.str());
  }

  TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m,
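
After this hunk the accepted weight formats are: uint4 when a zero point is supplied, and uint4b8 / uint8b128 / fp8e4m3 when it is not (fp8 is the new addition, int8-with-zero-point is dropped). The same rule as one small Python predicate, using the scalar-type names from the diff:

def q_type_allowed(q_type: str, has_zp: bool) -> bool:
    # Mirrors the TORCH_CHECKs above.
    if has_zp:
        return q_type == "kU4"
    return q_type in ("kU4B8", "kU8B128", "kFE4M3fn")
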
@ -631,18 +595,18 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
  int thread_k_blocks = thread_k / 16;
  int thread_n_blocks = thread_n / 16;

  TORCH_CHECK(is_valid_config(thread_tfg, thread_m_blocks, prob_m, prob_n,
                              prob_k, num_bits, group_size, has_act_order,
                              is_k_full, has_zp, is_zp_float, max_shared_mem),
              "Invalid thread config: thread_m_blocks = ", thread_m_blocks,
              ", thread_k = ", thread_tfg.thread_k,
              ", thread_n = ", thread_tfg.thread_n,
              ", num_threads = ", thread_tfg.num_threads, " for MKN = [",
              prob_m, ", ", prob_k, ", ", prob_n, "] and num_bits = ", num_bits,
              ", group_size = ", group_size,
              ", has_act_order = ", has_act_order, ", is_k_full = ", is_k_full,
              ", has_zp = ", has_zp, ", is_zp_float = ", is_zp_float,
              ", max_shared_mem = ", max_shared_mem);
  TORCH_CHECK(
      is_valid_config(thread_tfg, m_block_size_8, thread_m_blocks, prob_m,
                      prob_n, prob_k, num_bits, group_size, has_act_order,
                      is_k_full, has_zp, is_zp_float, max_shared_mem),
      "Invalid thread config: thread_m_blocks = ", thread_m_blocks,
      ", thread_k = ", thread_tfg.thread_k,
      ", thread_n = ", thread_tfg.thread_n,
      ", num_threads = ", thread_tfg.num_threads, " for MKN = [", prob_m, ", ",
      prob_k, ", ", prob_n, "] and num_bits = ", num_bits,
      ", group_size = ", group_size, ", has_act_order = ", has_act_order,
      ", is_k_full = ", is_k_full, ", has_zp = ", has_zp,
      ", is_zp_float = ", is_zp_float, ", max_shared_mem = ", max_shared_mem);

  auto kernel = get_marlin_kernel<scalar_t>(
      q_type, thread_m_blocks, thread_n_blocks, thread_k_blocks, m_block_size_8,
@ -666,7 +630,7 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
      A_ptr, B_ptr, C_ptr, C_tmp_ptr, s_ptr, zp_ptr, g_idx_ptr,
      sorted_token_ids_ptr, expert_ids_ptr, num_tokens_past_padded_ptr,
      topk_weights_ptr, top_k, mul_topk_weights, is_ep, num_groups, prob_m,
      prob_n, prob_k, locks, use_atomic_add, use_fp32_reduce);
      prob_n, prob_k, locks, use_atomic_add, use_fp32_reduce, max_shared_mem);
  // clang-format on
}

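For reference, the launch setup above works in 16-wide tiles and uses moe_block_size == 8 to select the new 8-row m-block kernels; a tiny worked example of that conversion (sample numbers, not taken from the diff):

# thread_k/thread_n are in elements; the kernel template wants 16x16 tile counts.
thread_k, thread_n, moe_block_size = 64, 256, 8
thread_k_blocks = thread_k // 16       # -> 4
thread_n_blocks = thread_n // 16       # -> 16
m_block_size_8 = moe_block_size == 8   # -> True: use the m_block_size_8 kernels
print(thread_k_blocks, thread_n_blocks, m_block_size_8)
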
@ -841,10 +805,11 @@ torch::Tensor moe_wna16_marlin_gemm(
        b_q_type == vllm::kU4,
        "b_q_type must be u4 when has_zp = True. Got = ", b_q_type.str());
  } else {
    TORCH_CHECK(
        b_q_type == vllm::kU4B8 || b_q_type == vllm::kU8B128,
        "b_q_type must be uint4b8 or uint8b128 when has_zp = False. Got = ",
        b_q_type.str());
    TORCH_CHECK(b_q_type == vllm::kU4B8 || b_q_type == vllm::kU8B128 ||
                    b_q_type == vllm::kFE4M3fn,
                "b_q_type must be uint4b8, uint8b128 or fp8e4m3 when has_zp = "
                "False. Got = ",
                b_q_type.str());
  }

  if (has_zp && is_zp_float) {

1
csrc/quantization/gptq_marlin/.gitignore
vendored
Normal file
@ -0,0 +1 @@
kernel_*.cu

291
csrc/quantization/gptq_marlin/dequant.h
Normal file
@ -0,0 +1,291 @@

#include "marlin_dtypes.cuh"

namespace MARLIN_NAMESPACE_NAME {

#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800
// Lookup-table based 3-input logical operation; explicitly used for
// dequantization as the compiler does not seem to automatically recognize it in
// all cases.
template <int lut>
__device__ inline int lop3(int a, int b, int c) {
  int res;
  asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
               : "=r"(res)
               : "r"(a), "r"(b), "r"(c), "n"(lut));
  return res;
}

// Constructs destination register by taking bytes from 2 sources (based on
// mask)
template <int start_byte, int mask>
__device__ inline uint32_t prmt(uint32_t a) {
  uint32_t res;
  asm volatile("prmt.b32 %0, %1, %2, %3;\n"
               : "=r"(res)
               : "r"(a), "n"(start_byte), "n"(mask));
  return res;
}

template <typename scalar_t2, vllm::ScalarTypeId w_type_id>
__device__ inline void dequant(int q, scalar_t2* frag_b);

//
// Efficiently dequantize 4bit values packed in an int32 value into a full
// B-fragment of 4 fp16 values. We mostly follow the strategy in the link below,
// with some small changes:
// - FP16:
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L215-L287
// - BF16:
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L327-L385
//
template <>
__device__ inline void dequant<half2, vllm::kU4B8.id()>(int q, half2* frag_b) {
  const int LO = 0x000f000f;
  const int HI = 0x00f000f0;
  const int EX = 0x64006400;
  // Guarantee that the `(a & b) | c` operations are LOP3s.
  // clang-format off
  int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX);
  int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX);
  // clang-format on
  // We want signed int4 outputs, hence we fuse the `-8` symmetric zero point
  // directly into `SUB` and `ADD`.
  const int SUB = 0x64086408;
  const int MUL = 0x2c002c00;
  const int ADD = 0xd480d480;
  frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo),
                      *reinterpret_cast<const half2*>(&SUB));
  frag_b[1] = __hfma2(*reinterpret_cast<half2*>(&hi),
                      *reinterpret_cast<const half2*>(&MUL),
                      *reinterpret_cast<const half2*>(&ADD));
}

template <>
__device__ inline void dequant<half2, vllm::kU4.id()>(int q, half2* frag_b) {
  const int LO = 0x000f000f;
  const int HI = 0x00f000f0;
  const int EX = 0x64006400;
  // Guarantee that the `(a & b) | c` operations are LOP3s.
  // clang-format off
  int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX);
  int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX);
  // clang-format on
  // We want signed int4 outputs, hence we fuse the `-8` symmetric zero point
  // directly into `SUB` and `ADD`.
  const int SUB = 0x64006400;
  const int MUL = 0x2c002c00;
  const int ADD = 0xd400d400;
  frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo),
                      *reinterpret_cast<const half2*>(&SUB));
  frag_b[1] = __hfma2(*reinterpret_cast<half2*>(&hi),
                      *reinterpret_cast<const half2*>(&MUL),
                      *reinterpret_cast<const half2*>(&ADD));
}

template <>
__device__ inline void dequant<nv_bfloat162, vllm::kU4B8.id()>(
    int q, nv_bfloat162* frag_b) {
  static constexpr uint32_t MASK = 0x000f000f;
  static constexpr uint32_t EX = 0x43004300;

  // Guarantee that the `(a & b) | c` operations are LOP3s.
  // clang-format off
  int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
  q >>= 4;
  int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
  // clang-format on

  static constexpr uint32_t MUL = 0x3F803F80;
  static constexpr uint32_t ADD = 0xC308C308;

  frag_b[0] = __hfma2(*reinterpret_cast<nv_bfloat162*>(&lo),
                      *reinterpret_cast<const nv_bfloat162*>(&MUL),
                      *reinterpret_cast<const nv_bfloat162*>(&ADD));
  frag_b[1] = __hfma2(*reinterpret_cast<nv_bfloat162*>(&hi),
                      *reinterpret_cast<const nv_bfloat162*>(&MUL),
                      *reinterpret_cast<const nv_bfloat162*>(&ADD));
}

template <>
__device__ inline void dequant<nv_bfloat162, vllm::kU4.id()>(
    int q, nv_bfloat162* frag_b) {
  static constexpr uint32_t MASK = 0x000f000f;
  static constexpr uint32_t EX = 0x43004300;

  // Guarantee that the `(a & b) | c` operations are LOP3s.
  // clang-format off
  int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
  q >>= 4;
  int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
  // clang-format on

  static constexpr uint32_t MUL = 0x3F803F80;
  static constexpr uint32_t ADD = 0xC300C300;

  frag_b[0] = __hfma2(*reinterpret_cast<nv_bfloat162*>(&lo),
                      *reinterpret_cast<const nv_bfloat162*>(&MUL),
                      *reinterpret_cast<const nv_bfloat162*>(&ADD));
  frag_b[1] = __hfma2(*reinterpret_cast<nv_bfloat162*>(&hi),
                      *reinterpret_cast<const nv_bfloat162*>(&MUL),
                      *reinterpret_cast<const nv_bfloat162*>(&ADD));
}

//
// Fast Int8ToFp16/Int8ToBf16: Efficiently dequantize 8bit int values to fp16 or
// bf16 Reference:
// - FP16:
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L53-L85
// - BF16:
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L125-L175
//
template <>
__device__ inline void dequant<half2, vllm::kU8B128.id()>(int q,
                                                          half2* frag_b) {
  static constexpr uint32_t mask_for_elt_01 = 0x5250;
  static constexpr uint32_t mask_for_elt_23 = 0x5351;
  static constexpr uint32_t start_byte_for_fp16 = 0x64646464;

  uint32_t lo = prmt<start_byte_for_fp16, mask_for_elt_01>(q);
  uint32_t hi = prmt<start_byte_for_fp16, mask_for_elt_23>(q);

  static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480;

  frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo),
                      *reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
  frag_b[1] = __hsub2(*reinterpret_cast<half2*>(&hi),
                      *reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
}

template <>
__device__ inline void dequant<half2, vllm::kU8.id()>(int q, half2* frag_b) {
  static constexpr uint32_t mask_for_elt_01 = 0x5250;
  static constexpr uint32_t mask_for_elt_23 = 0x5351;
  static constexpr uint32_t start_byte_for_fp16 = 0x64646464;

  uint32_t lo = prmt<start_byte_for_fp16, mask_for_elt_01>(q);
  uint32_t hi = prmt<start_byte_for_fp16, mask_for_elt_23>(q);

  static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64006400;

  frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo),
                      *reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
  frag_b[1] = __hsub2(*reinterpret_cast<half2*>(&hi),
                      *reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
}

template <>
__device__ inline void dequant<nv_bfloat162, vllm::kU8B128.id()>(
    int q, nv_bfloat162* frag_b) {
  float fp32_intermediates[4];
  uint32_t* fp32_intermediates_casted =
      reinterpret_cast<uint32_t*>(fp32_intermediates);

  static constexpr uint32_t fp32_base = 0x4B000000;
  fp32_intermediates_casted[0] = __byte_perm(q, fp32_base, 0x7650);
  fp32_intermediates_casted[1] = __byte_perm(q, fp32_base, 0x7652);
  fp32_intermediates_casted[2] = __byte_perm(q, fp32_base, 0x7651);
  fp32_intermediates_casted[3] = __byte_perm(q, fp32_base, 0x7653);

  fp32_intermediates[0] -= 8388736.f;
  fp32_intermediates[1] -= 8388736.f;
  fp32_intermediates[2] -= 8388736.f;
  fp32_intermediates[3] -= 8388736.f;

  uint32_t* bf16_result_ptr = reinterpret_cast<uint32_t*>(frag_b);
  bf16_result_ptr[0] = __byte_perm(fp32_intermediates_casted[0],
                                   fp32_intermediates_casted[1], 0x7632);
  bf16_result_ptr[1] = __byte_perm(fp32_intermediates_casted[2],
                                   fp32_intermediates_casted[3], 0x7632);
}

template <>
__device__ inline void dequant<nv_bfloat162, vllm::kU8.id()>(
    int q, nv_bfloat162* frag_b) {
  float fp32_intermediates[4];
  uint32_t* fp32_intermediates_casted =
      reinterpret_cast<uint32_t*>(fp32_intermediates);

  static constexpr uint32_t fp32_base = 0x4B000000;
  fp32_intermediates_casted[0] = __byte_perm(q, fp32_base, 0x7650);
  fp32_intermediates_casted[1] = __byte_perm(q, fp32_base, 0x7652);
  fp32_intermediates_casted[2] = __byte_perm(q, fp32_base, 0x7651);
  fp32_intermediates_casted[3] = __byte_perm(q, fp32_base, 0x7653);

  fp32_intermediates[0] -= 8388608.f;
  fp32_intermediates[1] -= 8388608.f;
  fp32_intermediates[2] -= 8388608.f;
  fp32_intermediates[3] -= 8388608.f;

  uint32_t* bf16_result_ptr = reinterpret_cast<uint32_t*>(frag_b);
  bf16_result_ptr[0] = __byte_perm(fp32_intermediates_casted[0],
                                   fp32_intermediates_casted[1], 0x7632);
  bf16_result_ptr[1] = __byte_perm(fp32_intermediates_casted[2],
                                   fp32_intermediates_casted[3], 0x7632);
}

template <>
__device__ inline void dequant<half2, vllm::kFE4M3fn.id()>(int q,
                                                           half2* frag_b) {
  // Constants for FP8 (E4M3) and FP16 formats
  constexpr int FP8_EXPONENT = 4, FP8_MANTISSA = 3, FP16_EXPONENT = 5;
  constexpr int RIGHT_SHIFT = FP16_EXPONENT - FP8_EXPONENT;

  // Calculate MASK for extracting mantissa and exponent
  constexpr int MASK1 = 0x80000000;
  constexpr int MASK2 = MASK1 >> (FP8_EXPONENT + FP8_MANTISSA);
  constexpr int MASK3 = MASK2 & 0x7fffffff;
  constexpr int MASK = MASK3 | (MASK3 >> 16);
  // Final MASK value: 0x7F007F00

  // Extract and shift FP8 values to FP16 format
  int Out1 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT);
  int Out2 = ((q << 8) & 0x80008000) | (((q << 8) & MASK) >> RIGHT_SHIFT);

  // Construct and apply exponent bias
  constexpr int BIAS_OFFSET =
      (1 << (FP16_EXPONENT - 1)) - (1 << (FP8_EXPONENT - 1));
  const half2 bias_reg = __float2half2_rn(float(1 << BIAS_OFFSET));

  // Convert to half2 and apply bias
  // Note: reverse indexing is intentional because weights are permuted
  frag_b[1] = __hmul2(*reinterpret_cast<const half2*>(&Out1), bias_reg);
  frag_b[0] = __hmul2(*reinterpret_cast<const half2*>(&Out2), bias_reg);
}

template <>
__device__ inline void dequant<nv_bfloat162, vllm::kFE4M3fn.id()>(
    int q, nv_bfloat162* frag_b) {
  // Constants for FP8 (E4M3) and BF16 formats
  constexpr int FP8_EXPONENT = 4, FP8_MANTISSA = 3, BF16_EXPONENT = 8;
  constexpr int RIGHT_SHIFT = BF16_EXPONENT - FP8_EXPONENT;

  // Calculate MASK for extracting mantissa and exponent
  constexpr int MASK1 = 0x80000000;
  constexpr int MASK2 = MASK1 >> (FP8_EXPONENT + FP8_MANTISSA);
  constexpr int MASK3 = MASK2 & 0x7fffffff;
  constexpr int MASK = MASK3 | (MASK3 >> 16);
  // Final MASK value: 0x7F007F00

  // Extract and shift FP8 values to BF16 format
  int Out1 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT);
  int Out2 = ((q << 8) & 0x80008000) | (((q << 8) & MASK) >> RIGHT_SHIFT);

  // Construct and apply exponent bias
  constexpr int BIAS_OFFSET =
      (1 << (BF16_EXPONENT - 1)) - (1 << (FP8_EXPONENT - 1));
  // Add 127 (float exponent bias) to BIAS_OFFSET and shift to float exponent
  // position
  constexpr uint32_t BIAS = (BIAS_OFFSET + 127) << 23;
  const nv_bfloat162 bias_reg =
      __float2bfloat162_rn(*reinterpret_cast<const float*>(&BIAS));

  // Convert to bfloat162 and apply bias
  // Note: reverse indexing is intentional because weights are permuted
  frag_b[1] = __hmul2(*reinterpret_cast<const nv_bfloat162*>(&Out1), bias_reg);
  frag_b[0] = __hmul2(*reinterpret_cast<const nv_bfloat162*>(&Out2), bias_reg);
}

#endif

}  // namespace MARLIN_NAMESPACE_NAME
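
The magic constants in dequant.h can be sanity-checked on the host. The sketch below is my addition, not part of the PR: it reproduces three of the conversions in plain Python for one value at a time, while the kernel performs the same arithmetic on packed registers with LOP3/PRMT and half2/bfloat162 math.

import struct

def fp16(bits: int) -> float:
    # Reinterpret a 16-bit pattern as an IEEE half.
    return struct.unpack("<e", struct.pack("<H", bits))[0]

# int4, kU4B8: OR the nibble into the mantissa of 1024.0 (0x6400), then
# subtract 1032.0 (0x6408) -> v - 8. High nibble: (1024 + 16*v) * 0x2c00
# (= 1/16) + 0xd480 (= -72.0) -> v - 8.
for v in range(16):
    assert fp16(0x6400 | v) - fp16(0x6408) == v - 8
    assert fp16(0x6400 | (v << 4)) * fp16(0x2c00) + fp16(0xd480) == v - 8

# int8, kU8B128: PRMT places each byte under exponent 0x64 (-> 1024.0 + b);
# subtracting 0x6480 (= 1152.0) leaves b - 128.
for b in range(256):
    assert fp16(0x6400 | b) - fp16(0x6480) == b - 128

# int8 -> bf16 goes through fp32: 0x4B000000 is 2**23 = 8388608.0, so the byte
# lands in the low mantissa bits and subtracting 8388736 (= 2**23 + 128)
# leaves b - 128.
assert struct.unpack("<f", struct.pack("<I", 0x4B000000 | 200))[0] - 8388736.0 == 72.0

# fp8 e4m3, kFE4M3fn: keep the sign, shift the 7 exponent+mantissa bits right
# by one into the fp16 layout, then scale by 2**8 (= 1 << BIAS_OFFSET) to fix
# the 15 - 7 exponent-bias difference.
def fp8_e4m3_to_float(byte: int) -> float:
    sign = -1.0 if byte & 0x80 else 1.0
    exp, man = (byte >> 3) & 0xF, byte & 0x7
    if exp == 0:
        return sign * (man / 8) * 2.0 ** -6  # subnormal
    return sign * (1 + man / 8) * 2.0 ** (exp - 7)

for byte in range(256):
    if (byte & 0x7F) == 0x7F:  # NaN encodings have no finite reference
        continue
    half_bits = ((byte & 0x80) << 8) | (((byte & 0x7F) << 8) >> 1)
    assert fp16(half_bits) * 256.0 == fp8_e4m3_to_float(byte)
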
116
csrc/quantization/gptq_marlin/generate_kernels.py
Normal file
@ -0,0 +1,116 @@
# SPDX-License-Identifier: Apache-2.0
import glob
import itertools
import os
import subprocess

import jinja2

FILE_HEAD = """
// auto generated by generate.py
// clang-format off

#include "kernel.h"
#include "marlin_template.h"

namespace MARLIN_NAMESPACE_NAME {
""".strip()

TEMPLATE = ("template __global__ void Marlin<"
            "{{scalar_t}}, "
            "{{w_type_id}}, "
            "{{threads}}, "
            "{{thread_m_blocks}}, "
            "{{thread_n_blocks}}, "
            "{{thread_k_blocks}}, "
            "{{'true' if m_block_size_8 else 'false'}}, "
            "{{stages}}, "
            "{{group_blocks}}, "
            "{{'true' if is_zp_float else 'false'}}>"
            "( MARLIN_KERNEL_PARAMS );")

# int8 with zero point case (vllm::kU8) is also supported,
# we don't add it to reduce wheel size.
SCALAR_TYPES = ["vllm::kU4", "vllm::kU4B8", "vllm::kU8B128", "vllm::kFE4M3fn"]
THREAD_CONFIGS = [(128, 128, 256), (64, 256, 256), (64, 128, 128),
                  (128, 64, 128)]

THREAD_M_BLOCKS = [0.5, 1, 2, 3, 4]
# group_blocks:
#   = 0 : act order case
#   = -1 : channelwise quantization
#   > 0 : group_size=16*group_blocks
GROUP_BLOCKS = [0, -1, 2, 4, 8]
DTYPES = ["fp16", "bf16"]


def remove_old_kernels():
    for filename in glob.glob(os.path.dirname(__file__) + "/kernel_*.cu"):
        subprocess.call(["rm", "-f", filename])


def generate_new_kernels():
    for scalar_type, dtype in itertools.product(SCALAR_TYPES, DTYPES):
        all_template_str_list = []

        for group_blocks, m_blocks, thread_configs in itertools.product(
                GROUP_BLOCKS, THREAD_M_BLOCKS, THREAD_CONFIGS):

            # act order case only support gptq-int4 and gptq-int8
            if group_blocks == 0 and scalar_type not in [
                    "vllm::kU4B8", "vllm::kU8B128"
            ]:
                continue
            if thread_configs[2] == 256:
                # for small batch (m_blocks == 1), we only need (128, 128, 256)
                # for large batch (m_blocks > 1), we only need (64, 256, 256)
                if m_blocks <= 1 and thread_configs[0] != 128:
                    continue
                if m_blocks > 1 and thread_configs[0] != 64:
                    continue

            # we only support channelwise quantization and group_size == 128
            # for fp8
            if scalar_type == "vllm::kFE4M3fn" and group_blocks not in [-1, 8]:
                continue

            k_blocks = thread_configs[0] // 16
            n_blocks = thread_configs[1] // 16
            threads = thread_configs[2]

            c_dtype = "half" if dtype == "fp16" else "nv_bfloat16"

            is_zp_float_list = [False]
            if dtype == "fp16" and scalar_type == "vllm::kU4" and \
                    group_blocks == 4:
                # HQQ (is_zp_float = true) only supports
                # 4bit quantization and fp16
                is_zp_float_list.append(True)

            for is_zp_float in is_zp_float_list:
                template_str = jinja2.Template(TEMPLATE).render(
                    scalar_t=c_dtype,
                    w_type_id=scalar_type + ".id()",
                    threads=threads,
                    thread_m_blocks=max(m_blocks, 1),
                    thread_n_blocks=n_blocks,
                    thread_k_blocks=k_blocks,
                    m_block_size_8=m_blocks == 0.5,
                    stages="pipe_stages",
                    group_blocks=group_blocks,
                    is_zp_float=is_zp_float,
                )

                all_template_str_list.append(template_str)

        file_content = FILE_HEAD + "\n\n"
        file_content += "\n\n".join(all_template_str_list) + "\n\n}\n"
        filename = f"kernel_{dtype}_{scalar_type[6:].lower()}.cu"

        with open(os.path.join(os.path.dirname(__file__), filename), "w") as f:
            f.write(file_content)


if __name__ == "__main__":
    remove_old_kernels()
    generate_new_kernels()
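
For a sense of how much the pruning in generate_new_kernels() cuts down the template cross product, the companion sketch below (not part of the PR) reuses the same filtering rules to count the Marlin<> instantiations that land in each generated kernel_<dtype>_<type>.cu; it only counts and writes nothing.

import itertools

SCALAR_TYPES = ["vllm::kU4", "vllm::kU4B8", "vllm::kU8B128", "vllm::kFE4M3fn"]
THREAD_CONFIGS = [(128, 128, 256), (64, 256, 256), (64, 128, 128),
                  (128, 64, 128)]
THREAD_M_BLOCKS = [0.5, 1, 2, 3, 4]
GROUP_BLOCKS = [0, -1, 2, 4, 8]
DTYPES = ["fp16", "bf16"]

def count_instantiations(scalar_type: str, dtype: str) -> int:
    count = 0
    for group_blocks, m_blocks, tc in itertools.product(
            GROUP_BLOCKS, THREAD_M_BLOCKS, THREAD_CONFIGS):
        if group_blocks == 0 and scalar_type not in ("vllm::kU4B8",
                                                     "vllm::kU8B128"):
            continue
        if tc[2] == 256:
            if m_blocks <= 1 and tc[0] != 128:
                continue
            if m_blocks > 1 and tc[0] != 64:
                continue
        if scalar_type == "vllm::kFE4M3fn" and group_blocks not in (-1, 8):
            continue
        # The HQQ variant (is_zp_float=True) adds a second instantiation.
        hqq = (dtype == "fp16" and scalar_type == "vllm::kU4"
               and group_blocks == 4)
        count += 2 if hqq else 1
    return count

for st, dt in itertools.product(SCALAR_TYPES, DTYPES):
    print(f"kernel_{dt}_{st[6:].lower()}.cu:", count_instantiations(st, dt))
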
File diff suppressed because it is too large
37
csrc/quantization/gptq_marlin/kernel.h
Normal file
@ -0,0 +1,37 @@

#ifndef MARLIN_NAMESPACE_NAME
  #define MARLIN_NAMESPACE_NAME marlin
#endif

#include "marlin.cuh"
#include "marlin_dtypes.cuh"
#include "core/scalar_type.hpp"

#define MARLIN_KERNEL_PARAMS                                                 \
  const int4 *__restrict__ A, const int4 *__restrict__ B,                    \
      int4 *__restrict__ C, int4 *__restrict__ C_tmp,                        \
      const int4 *__restrict__ scales_ptr, const int4 *__restrict__ zp_ptr,  \
      const int *__restrict__ g_idx, int num_groups, int prob_m, int prob_n, \
      int prob_k, int lda, int *locks, bool use_atomic_add,                  \
      bool use_fp32_reduce, int max_shared_mem

namespace MARLIN_NAMESPACE_NAME {
template <typename scalar_t,          // compute dtype, half or nv_float16
          const vllm::ScalarTypeId w_type_id,  // weight ScalarType id
          const int threads,          // number of threads in a threadblock
          const int thread_m_blocks,  // number of 16x16 blocks in the m
                                      // dimension (batchsize) of the
                                      // threadblock
          const int thread_n_blocks,  // same for n dimension (output)
          const int thread_k_blocks,  // same for k dimension (reduction)
          const bool m_block_size_8,  // whether m_block_size == 8
                                      // only works when thread_m_blocks == 1
          const int stages,  // number of stages for the async global->shared
                             // fetch pipeline
          const int group_blocks,  // number of consecutive 16x16 blocks
                                   // with a separate quantization scale
          const bool is_zp_float   // is zero point of float16 type?
          >
__global__ void Marlin(MARLIN_KERNEL_PARAMS);

}

1678
csrc/quantization/gptq_marlin/marlin_template.h
Normal file
File diff suppressed because it is too large
@ -291,12 +291,11 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {

  // gptq_marlin Optimized Quantized GEMM for GPTQ.
  ops.def(
      "gptq_marlin_gemm(Tensor a, Tensor b_q_weight, Tensor b_scales, "
      "Tensor b_zeros, Tensor g_idx, Tensor perm, Tensor workspace, "
      "int b_q_type, "
      "gptq_marlin_gemm(Tensor a, Tensor? c_or_none, Tensor b_q_weight, "
      "Tensor b_scales, Tensor? b_zeros_or_none, Tensor? g_idx_or_none, "
      "Tensor? perm_or_none, Tensor workspace, int b_q_type, "
      "SymInt size_m, SymInt size_n, SymInt size_k, bool is_k_full, "
      "bool has_zp, bool use_atomic_add, bool use_fp32_reduce, "
      "bool is_zp_float) -> Tensor",
      "bool use_atomic_add, bool use_fp32_reduce, bool is_zp_float) -> Tensor",
      {stride_tag});
  // conditionally compiled so impl registration is in source file

@ -341,14 +340,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
  ops.def("ggml_moe_get_block_size", &ggml_moe_get_block_size);

#ifndef USE_ROCM
  // fp8_marlin Optimized Quantized GEMM for FP8 weight-only.
  ops.def(
      "fp8_marlin_gemm(Tensor a, Tensor b_q_weight, Tensor b_scales, "
      "Tensor! workspace, int num_bits, SymInt size_m, SymInt size_n, "
      "SymInt size_k) -> Tensor",
      {stride_tag});
  // conditionally compiled so impl registration is in source file

  // marlin_qqq_gemm for QQQ.
  ops.def(
      "marlin_qqq_gemm(Tensor a, Tensor b_q_weight, "
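
The schema change above makes the preallocated output, zero points, g_idx and perm optional (`Tensor?`) and drops `has_zp`, which is now implied by `b_q_type`. A condensed, hedged version of how the updated test drives the new entry point from Python (small illustrative shapes; requires a CUDA build of vLLM with these kernels):

import torch
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
    marlin_make_empty_g_idx, marlin_make_workspace_new)
from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
    marlin_quantize)
from vllm.scalar_type import scalar_types

size_m, size_k, size_n, group_size = 16, 256, 512, 128
a_input = torch.randn((size_m, size_k), dtype=torch.float16, device="cuda")
b_weight = torch.randn((size_k, size_n), dtype=torch.float16, device="cuda")

quant_type = scalar_types.uint4b8
w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, _ = marlin_quantize(
    b_weight, quant_type, group_size, False)          # act_order=False
marlin_zp = marlin_make_empty_g_idx(marlin_s.device)  # no zero points
workspace = marlin_make_workspace_new(a_input.device)

output = ops.gptq_marlin_gemm(
    a_input, None, marlin_q_w, marlin_s, marlin_zp, g_idx, sort_indices,
    workspace, quant_type, size_m, size_n, size_k,
    is_k_full=True, use_atomic_add=False, use_fp32_reduce=True,
    is_zp_float=False)
print((output - a_input @ w_ref).abs().max())  # expected to be small
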
@ -11,19 +11,20 @@ from transformers import MixtralConfig
|
||||
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
|
||||
|
||||
import vllm.model_executor.layers.fused_moe # noqa
|
||||
from tests.kernels.utils import (opcheck, stack_and_dev, torch_moe,
|
||||
torch_moe_single)
|
||||
from tests.kernels.utils import opcheck, stack_and_dev, torch_moe
|
||||
from vllm.model_executor.layers.fused_moe import fused_moe
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
|
||||
from vllm.model_executor.layers.fused_moe.moe_torch_iterative import (
|
||||
fused_moe as iterative_moe)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
|
||||
marlin_quant_fp8_torch)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
|
||||
awq_marlin_quantize, marlin_quantize)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
quantize_weights)
|
||||
from vllm.model_executor.models.mixtral import MixtralMoE
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.scalar_type import scalar_types
|
||||
from vllm.scalar_type import ScalarType, scalar_types
|
||||
|
||||
NUM_EXPERTS = [8, 64]
|
||||
EP_SIZE = [1, 4]
|
||||
@ -285,7 +286,7 @@ def test_mixtral_moe(dtype: torch.dtype, padding: bool, use_rocm_aiter: bool,
|
||||
atol=mixtral_moe_tol[dtype])
|
||||
|
||||
|
||||
@pytest.mark.parametrize("m", [1, 33, 123])
|
||||
@pytest.mark.parametrize("m", [1, 123, 666])
|
||||
@pytest.mark.parametrize("n", [128, 1024])
|
||||
@pytest.mark.parametrize("k", [256, 2048])
|
||||
@pytest.mark.parametrize("e", [4, 12])
|
||||
@ -294,8 +295,10 @@ def test_mixtral_moe(dtype: torch.dtype, padding: bool, use_rocm_aiter: bool,
|
||||
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
|
||||
@pytest.mark.parametrize("group_size", [-1, 32, 128])
|
||||
@pytest.mark.parametrize("act_order", [True, False])
|
||||
@pytest.mark.parametrize("num_bits", [4, 8])
|
||||
@pytest.mark.parametrize("has_zp", [True, False])
|
||||
@pytest.mark.parametrize("quant_type", [
|
||||
scalar_types.uint4, scalar_types.uint8b128, scalar_types.uint4b8,
|
||||
scalar_types.float8_e4m3fn
|
||||
])
|
||||
@pytest.mark.parametrize("is_k_full", [True, False])
|
||||
@pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm")
|
||||
def test_fused_marlin_moe(
|
||||
@ -308,14 +311,22 @@ def test_fused_marlin_moe(
|
||||
dtype: torch.dtype,
|
||||
group_size: int,
|
||||
act_order: bool,
|
||||
num_bits: int,
|
||||
has_zp: bool,
|
||||
quant_type: ScalarType,
|
||||
is_k_full: bool,
|
||||
):
|
||||
current_platform.seed_everything(7)
|
||||
torch.cuda.manual_seed(0)
|
||||
has_zp = quant_type in [scalar_types.uint4, scalar_types.uint8]
|
||||
|
||||
if quant_type == scalar_types.float8_e4m3fn:
|
||||
if group_size not in [-1, 128]:
|
||||
return
|
||||
if act_order:
|
||||
return
|
||||
|
||||
# Filter act_order
|
||||
if act_order:
|
||||
if quant_type == scalar_types.float8_e4m3fn:
|
||||
return
|
||||
if group_size == -1:
|
||||
return
|
||||
if group_size in (k, n):
|
||||
@ -326,17 +337,9 @@ def test_fused_marlin_moe(
|
||||
if not is_k_full:
|
||||
return
|
||||
|
||||
if has_zp:
|
||||
# we don't build kernel for int8 with zero
|
||||
if num_bits == 8:
|
||||
return
|
||||
quant_type = scalar_types.uint4 if num_bits == 4 else scalar_types.uint8
|
||||
else:
|
||||
quant_type = scalar_types.uint4b8 \
|
||||
if num_bits == 4 else scalar_types.uint8b128
|
||||
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
|
||||
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
|
||||
w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
|
||||
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 20
|
||||
w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 20
|
||||
|
||||
if ep_size > 1:
|
||||
local_e = e // ep_size
|
||||
@ -364,17 +367,23 @@ def test_fused_marlin_moe(
|
||||
qweight1_l.append(qweight1)
|
||||
scales1_l.append(scales1)
|
||||
zeros1_l.append(zeros1)
|
||||
else:
|
||||
elif quant_type != scalar_types.float8_e4m3fn:
|
||||
test_perm = torch.randperm(k)
|
||||
quant_res = marlin_quantize(w1[i].transpose(1, 0), quant_type,
|
||||
group_size, act_order, test_perm)
|
||||
w_ref1, qweight1, scales1, g_idx1, sort_indices1, _ = quant_res
|
||||
w_ref1, qweight1, scales1, g_idx1, sort_indices1, _ = \
|
||||
marlin_quantize(w1[i].transpose(1, 0), quant_type,
|
||||
group_size, act_order, test_perm)
|
||||
|
||||
w_ref1_l.append(w_ref1.T)
|
||||
qweight1_l.append(qweight1)
|
||||
scales1_l.append(scales1)
|
||||
g_idx1_l.append(g_idx1)
|
||||
sort_indices1_l.append(sort_indices1)
|
||||
else:
|
||||
w_ref1, qweight1, scales1 = marlin_quant_fp8_torch(
|
||||
w1[i], group_size)
|
||||
w_ref1_l.append(w_ref1.T)
|
||||
qweight1_l.append(qweight1)
|
||||
scales1_l.append(scales1)
|
||||
|
||||
w_ref1 = stack_and_dev(w_ref1_l)
|
||||
qweight1 = stack_and_dev(qweight1_l).contiguous()
|
||||
@ -399,17 +408,23 @@ def test_fused_marlin_moe(
|
||||
qweight2_l.append(qweight2)
|
||||
scales2_l.append(scales2)
|
||||
zeros2_l.append(zeros2)
|
||||
else:
|
||||
elif quant_type != scalar_types.float8_e4m3fn:
|
||||
test_perm = torch.randperm(n)
|
||||
quant_res = marlin_quantize(w2[i].transpose(1, 0), quant_type,
|
||||
group_size, act_order, test_perm)
|
||||
w_ref2, qweight2, scales2, g_idx2, sort_indices2, _ = quant_res
|
||||
w_ref2, qweight2, scales2, g_idx2, sort_indices2, _ = \
|
||||
marlin_quantize(w2[i].transpose(1, 0), quant_type,
|
||||
group_size, act_order, test_perm)
|
||||
|
||||
w_ref2_l.append(w_ref2.T)
|
||||
qweight2_l.append(qweight2)
|
||||
scales2_l.append(scales2)
|
||||
g_idx2_l.append(g_idx2)
|
||||
sort_indices2_l.append(sort_indices2)
|
||||
else:
|
||||
w_ref2, qweight2, scales2 = marlin_quant_fp8_torch(
|
||||
w2[i], group_size)
|
||||
w_ref2_l.append(w_ref2.T)
|
||||
qweight2_l.append(qweight2)
|
||||
scales2_l.append(scales2)
|
||||
|
||||
w_ref2 = stack_and_dev(w_ref2_l)
|
||||
qweight2 = stack_and_dev(qweight2_l).contiguous()
|
||||
@ -442,102 +457,10 @@ def test_fused_marlin_moe(
|
||||
sort_indices2=sort_indices2,
|
||||
w1_zeros=zeros1,
|
||||
w2_zeros=zeros2,
|
||||
num_bits=num_bits,
|
||||
quant_type_id=quant_type.id,
|
||||
is_k_full=is_k_full)
|
||||
|
||||
torch.testing.assert_close(marlin_output, torch_output, atol=2e-2, rtol=0)
|
||||
|
||||
|
||||
@pytest.mark.skip("This test is here for the sake of debugging, "
|
||||
"don't run it in automated tests.")
|
||||
@pytest.mark.parametrize("m", [1, 33, 123])
|
||||
@pytest.mark.parametrize("n", [128, 1024])
|
||||
@pytest.mark.parametrize("k", [256, 2048])
|
||||
@pytest.mark.parametrize("e", [4, 12])
|
||||
@pytest.mark.parametrize("topk", [2, 3])
|
||||
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
|
||||
@pytest.mark.parametrize("group_size", [-1, 32, 128])
|
||||
@pytest.mark.parametrize("act_order", [True, False])
|
||||
@pytest.mark.parametrize("num_bits", [4, 8])
|
||||
@pytest.mark.parametrize("has_zp", [True, False])
|
||||
@pytest.mark.parametrize("is_k_full", [True, False])
|
||||
def test_single_marlin_moe_multiply(m: int, n: int, k: int, e: int, topk: int,
|
||||
dtype: torch.dtype, group_size: int,
|
||||
act_order: bool, num_bits: int,
|
||||
has_zp: bool, is_k_full: bool):
|
||||
# Filter act_order
|
||||
if act_order:
|
||||
if group_size == -1:
|
||||
return
|
||||
if group_size in (k, n):
|
||||
return
|
||||
if has_zp:
|
||||
return
|
||||
else:
|
||||
if not is_k_full:
|
||||
return
|
||||
|
||||
if has_zp:
|
||||
quant_type = scalar_types.uint4 if num_bits == 4 else scalar_types.uint8
|
||||
else:
|
||||
quant_type = scalar_types.uint4b8 \
|
||||
if num_bits == 4 else scalar_types.uint8b128
|
||||
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
|
||||
w = torch.randn((e, n, k), device="cuda", dtype=dtype) / 10
|
||||
|
||||
w_ref_l = []
|
||||
qweight_l = []
|
||||
scales_l = []
|
||||
zeros_l = []
|
||||
g_idx_l = []
|
||||
sort_indices_l = []
|
||||
|
||||
for i in range(w.shape[0]):
|
||||
if has_zp:
|
||||
w_ref, qweight, scales, zeros = awq_marlin_quantize(
|
||||
w[i].transpose(1, 0), quant_type, group_size)
|
||||
|
||||
w_ref_l.append(w_ref.T)
|
||||
qweight_l.append(qweight)
|
||||
scales_l.append(scales)
|
||||
zeros_l.append(zeros)
|
||||
else:
|
||||
test_perm = torch.randperm(k)
|
||||
w_ref, qweight, scales, g_idx, sort_indices, _ = marlin_quantize(
|
||||
w[i].transpose(1, 0), quant_type, group_size, act_order,
|
||||
test_perm)
|
||||
|
||||
w_ref_l.append(w_ref.T)
|
||||
qweight_l.append(qweight)
|
||||
scales_l.append(scales)
|
||||
g_idx_l.append(g_idx)
|
||||
sort_indices_l.append(sort_indices)
|
||||
|
||||
w_ref = stack_and_dev(w_ref_l)
|
||||
qweight = stack_and_dev(qweight_l).contiguous()
|
||||
scales = stack_and_dev(scales_l)
|
||||
g_idx = stack_and_dev(g_idx_l) if g_idx_l else None
|
||||
zeros = stack_and_dev(zeros_l) if zeros_l else None
|
||||
sort_indices = stack_and_dev(sort_indices_l) if sort_indices_l else None
|
||||
|
||||
score = torch.randn((m, e), device="cuda", dtype=dtype)
|
||||
marlin_output = torch.ops.vllm.single_marlin_moe(
|
||||
a,
|
||||
qweight,
|
||||
scales,
|
||||
score,
|
||||
topk,
|
||||
renormalize=False,
|
||||
g_idx=g_idx,
|
||||
sort_indices=sort_indices,
|
||||
w_zeros=zeros,
|
||||
num_bits=num_bits,
|
||||
is_k_full=is_k_full,
|
||||
)
|
||||
|
||||
torch_output = torch_moe_single(a, w_ref, score, topk)
|
||||
|
||||
torch.testing.assert_close(marlin_output, torch_output, atol=2e-2, rtol=0)
|
||||
torch.testing.assert_close(marlin_output, torch_output, atol=5e-2, rtol=0)
|
||||
|
||||
|
||||
def test_moe_align_block_size_opcheck():
|
||||
|
@ -1,164 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
"""Test AWQ with fused MoE Marlin kernels.
|
||||
|
||||
Run `pytest tests/kernels/test_awq_marlin.py`.
|
||||
"""
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
import vllm.model_executor.layers.fused_moe # noqa
|
||||
from tests.kernels.utils import (compute_max_diff, stack_and_dev, torch_moe,
|
||||
torch_moe_single)
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
|
||||
awq_marlin_quantize)
|
||||
from vllm.scalar_type import scalar_types
|
||||
|
||||
NUM_EXPERTS = [8, 64]
|
||||
TOP_KS = [2, 6]
|
||||
GROUP_SIZES = [-1, 32, 128]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("m", [1, 33, 64, 222])
|
||||
@pytest.mark.parametrize("n", [128, 2048])
|
||||
@pytest.mark.parametrize("k", [128, 1024])
|
||||
@pytest.mark.parametrize("e", NUM_EXPERTS)
|
||||
@pytest.mark.parametrize("topk", TOP_KS)
|
||||
@pytest.mark.parametrize("group_size", GROUP_SIZES)
|
||||
@pytest.mark.skipif(not (ops.supports_moe_ops
|
||||
and hasattr(torch.ops._moe_C, "marlin_gemm_moe")),
|
||||
reason="Marlin is not supported on this GPU type.")
|
||||
def test_fused_marlin_moe_awq(
|
||||
m: int,
|
||||
n: int,
|
||||
k: int,
|
||||
e: int,
|
||||
topk: int,
|
||||
group_size: int,
|
||||
):
|
||||
torch.manual_seed(7)
|
||||
|
||||
num_bits = 4
|
||||
quant_type = scalar_types.uint4
|
||||
dtype = torch.float16
|
||||
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
|
||||
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
|
||||
w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
|
||||
|
||||
w_ref1_l = []
|
||||
qweights1_l = []
|
||||
scales1_l = []
|
||||
zp1_l = []
|
||||
|
||||
for i in range(w1.shape[0]):
|
||||
w_ref1, qweight1, scales1, zp1 = awq_marlin_quantize(
|
||||
w1[i].transpose(1, 0), quant_type, group_size)
|
||||
w_ref1_l.append(w_ref1)
|
||||
qweights1_l.append(qweight1)
|
||||
scales1_l.append(scales1)
|
||||
zp1_l.append(zp1)
|
||||
|
||||
w_ref1 = stack_and_dev(w_ref1_l)
|
||||
qweight1 = stack_and_dev(qweights1_l).contiguous()
|
||||
scales1 = stack_and_dev(scales1_l)
|
||||
zp1 = stack_and_dev(zp1_l)
|
||||
|
||||
w_ref2_l = []
|
||||
qweights2_l = []
|
||||
scales2_l = []
|
||||
zp2_l = []
|
||||
|
||||
for i in range(w2.shape[0]):
|
||||
w_ref2, qweight2, scales2, zp2 = awq_marlin_quantize(
|
||||
w2[i].transpose(1, 0), quant_type, group_size)
|
||||
w_ref2_l.append(w_ref2)
|
||||
qweights2_l.append(qweight2)
|
||||
scales2_l.append(scales2)
|
||||
zp2_l.append(zp2)
|
||||
|
||||
w_ref2 = stack_and_dev(w_ref2_l)
|
||||
qweight2 = stack_and_dev(qweights2_l).contiguous()
|
||||
scales2 = stack_and_dev(scales2_l)
|
||||
zp2 = stack_and_dev(zp2_l)
|
||||
|
||||
score = torch.randn((m, e), device="cuda", dtype=dtype)
|
||||
|
||||
topk_weights, topk_ids, token_expert_indices = fused_topk(
|
||||
a, score, topk, False)
|
||||
marlin_output = torch.ops.vllm.fused_marlin_moe(
|
||||
a,
|
||||
qweight1,
|
||||
qweight2,
|
||||
scales1,
|
||||
scales2,
|
||||
score,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
w1_zeros=zp1,
|
||||
w2_zeros=zp2,
|
||||
num_bits=num_bits,
|
||||
)
|
||||
|
||||
torch_output = torch_moe(a, w_ref1.transpose(1, 2), w_ref2.transpose(1, 2),
|
||||
score, topk, None)
|
||||
|
||||
assert compute_max_diff(marlin_output, torch_output) < 4e-2
|
||||
|
||||
|
||||
@pytest.mark.skip("This test is here for the sake of debugging, "
|
||||
"don't run it in automated tests.")
|
||||
@pytest.mark.parametrize("m", [64, 512, 222, 33, 1])
|
||||
@pytest.mark.parametrize("n", [128, 2048, 256, 1024])
|
||||
@pytest.mark.parametrize("k", [128, 1024, 512])
|
||||
@pytest.mark.parametrize("e", [8, 64])
|
||||
@pytest.mark.parametrize("topk", [2, 6])
|
||||
@pytest.mark.parametrize("group_size", [-1, 32, 64, 128])
|
||||
def test_single_marlin_moe_multiply_awq(
|
||||
m: int,
|
||||
n: int,
|
||||
k: int,
|
||||
e: int,
|
||||
topk: int,
|
||||
group_size: int,
|
||||
):
|
||||
torch.manual_seed(7)
|
||||
|
||||
num_bits = 4
|
||||
quant_type = scalar_types.uint4
|
||||
dtype = torch.float16
|
||||
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
|
||||
w = torch.randn((e, n, k), device="cuda", dtype=dtype) / 10
|
||||
|
||||
w_ref_l = []
|
||||
qweights_l = []
|
||||
scales_l = []
|
||||
zp_l = []
|
||||
|
||||
for i in range(w.shape[0]):
|
||||
w_ref, qweight, scales, zp = awq_marlin_quantize(
|
||||
w[i].transpose(1, 0), quant_type, group_size)
|
||||
w_ref_l.append(w_ref)
|
||||
qweights_l.append(qweight)
|
||||
scales_l.append(scales)
|
||||
zp_l.append(zp)
|
||||
|
||||
w_ref = stack_and_dev(w_ref_l)
|
||||
qweight = stack_and_dev(qweights_l).contiguous()
|
||||
scales = stack_and_dev(scales_l).contiguous()
|
||||
zp = stack_and_dev(zp_l).contiguous()
|
||||
|
||||
score = torch.randn((m, e), device="cuda", dtype=dtype)
|
||||
|
||||
marlin_output = torch.ops.vllm.single_marlin_moe(a,
|
||||
qweight,
|
||||
scales,
|
||||
score,
|
||||
topk,
|
||||
renormalize=False,
|
||||
w_zeros=zp,
|
||||
num_bits=num_bits)
|
||||
|
||||
torch_output = torch_moe_single(a, w_ref.transpose(1, 2), score, topk)
|
||||
|
||||
assert compute_max_diff(marlin_output, torch_output) < 1e-2
|
@ -18,9 +18,10 @@ from vllm.model_executor.layers.quantization.qqq import (
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
|
||||
GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N,
|
||||
MARLIN_SUPPORTED_GROUP_SIZES, marlin_make_empty_g_idx,
|
||||
marlin_permute_scales, query_marlin_supported_quant_types)
|
||||
marlin_make_workspace_new, marlin_permute_scales,
|
||||
query_marlin_supported_quant_types)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
|
||||
pack_fp8_to_int32)
|
||||
marlin_quant_fp8_torch)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
|
||||
MarlinWorkspace, awq_marlin_quantize, get_weight_perm, marlin_quantize,
|
||||
marlin_weights)
|
||||
@ -73,7 +74,7 @@ def rand_data(shape, dtype=torch.float16):
|
||||
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
|
||||
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
|
||||
@pytest.mark.parametrize("quant_type",
|
||||
query_marlin_supported_quant_types(False))
|
||||
query_marlin_supported_quant_types(False, False))
|
||||
@pytest.mark.parametrize("group_size", MARLIN_SUPPORTED_GROUP_SIZES)
|
||||
@pytest.mark.parametrize("act_order", ACT_ORDER_OPTS)
|
||||
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
|
||||
@ -138,7 +139,7 @@ def test_gptq_marlin_repack(k_chunk, n_chunk, quant_type, group_size,
|
||||
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
|
||||
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
|
||||
@pytest.mark.parametrize("quant_type",
|
||||
query_marlin_supported_quant_types(False))
|
||||
query_marlin_supported_quant_types(True))
|
||||
@pytest.mark.parametrize("group_size", MARLIN_SUPPORTED_GROUP_SIZES)
|
||||
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
|
||||
def test_awq_marlin_repack(k_chunk, n_chunk, quant_type, group_size,
|
||||
@ -220,38 +221,50 @@ def test_gptq_marlin_gemm(
|
||||
if group_size == size_k:
|
||||
return
|
||||
|
||||
if size_k % group_size != 0:
|
||||
return
|
||||
|
||||
a_input = rand_data((size_m, size_k))
|
||||
b_weight = rand_data((size_k, size_n))
|
||||
|
||||
w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, _ = marlin_quantize(
|
||||
b_weight, quant_type, group_size, act_order)
|
||||
if quant_type == scalar_types.float8_e4m3fn:
|
||||
if group_size not in [-1, 128]:
|
||||
return
|
||||
if act_order:
|
||||
return
|
||||
w_ref, marlin_q_w, marlin_s = marlin_quant_fp8_torch(
|
||||
b_weight.T, group_size)
|
||||
g_idx = None
|
||||
sort_indices = None
|
||||
else:
|
||||
w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, _ = marlin_quantize(
|
||||
b_weight, quant_type, group_size, act_order)
|
||||
|
||||
marlin_zp = marlin_make_empty_g_idx(marlin_s.device)
|
||||
|
||||
workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_MIN_THREAD_N,
|
||||
GPTQ_MARLIN_MAX_PARALLEL)
|
||||
workspace = marlin_make_workspace_new(w_ref.device)
|
||||
|
||||
opcheck(torch.ops._C.gptq_marlin_gemm,
|
||||
(a_input, marlin_q_w, marlin_s, marlin_zp, g_idx, sort_indices,
|
||||
workspace.scratch, quant_type.id, a_input.shape[0],
|
||||
b_weight.shape[1], a_input.shape[1], is_k_full, False,
|
||||
use_atomic_add, use_fp32_reduce, False),
|
||||
test_utils=DEFAULT_OPCHECK_TEST_UTILS)
|
||||
opcheck(
|
||||
torch.ops._C.gptq_marlin_gemm,
|
||||
(a_input, None, marlin_q_w, marlin_s, marlin_zp, g_idx, sort_indices,
|
||||
workspace, quant_type.id, a_input.shape[0], b_weight.shape[1],
|
||||
a_input.shape[1], is_k_full, use_atomic_add, use_fp32_reduce, False),
|
||||
test_utils=DEFAULT_OPCHECK_TEST_UTILS)
|
||||
|
||||
output = ops.gptq_marlin_gemm(
|
||||
a_input,
|
||||
None,
|
||||
marlin_q_w,
|
||||
marlin_s,
|
||||
marlin_zp,
|
||||
g_idx,
|
||||
sort_indices,
|
||||
workspace.scratch,
|
||||
workspace,
|
||||
quant_type,
|
||||
a_input.shape[0],
|
||||
b_weight.shape[1],
|
||||
a_input.shape[1],
|
||||
is_k_full=is_k_full,
|
||||
has_zp=False,
|
||||
use_atomic_add=use_atomic_add,
|
||||
use_fp32_reduce=use_fp32_reduce,
|
||||
is_zp_float=False,
|
||||
@ -326,80 +339,6 @@ def test_gptq_marlin_24_gemm(k_chunk, n_chunk, quant_type, group_size,
|
||||
assert max_diff < 0.04
|
||||
|
||||
|
||||
@pytest.mark.skipif(not is_quant_method_supported("fp8"),
|
||||
reason="Marlin is not supported on this GPU type.")
|
||||
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
|
||||
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
|
||||
@pytest.mark.parametrize("num_bits", [8])
|
||||
@pytest.mark.parametrize("group_size", [-1])
|
||||
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
def test_fp8_marlin_gemm(
|
||||
k_chunk,
|
||||
n_chunk,
|
||||
num_bits,
|
||||
group_size,
|
||||
mnk_factors,
|
||||
dtype,
|
||||
):
|
||||
m_factor, n_factor, k_factor = mnk_factors
|
||||
|
||||
size_m = m_factor
|
||||
size_k = k_chunk * k_factor
|
||||
size_n = n_chunk * n_factor
|
||||
|
||||
a_input = rand_data((size_m, size_k), dtype=dtype)
|
||||
b_weight = rand_data((size_k, size_n), dtype=dtype)
|
||||
|
||||
# WEIGHTS
|
||||
fp8_weight, weight_scale = ops.scaled_fp8_quant(b_weight, scale=None)
|
||||
# Repack weights to gptq format (packed int32 elements)
|
||||
packed_gptq_qweight = pack_fp8_to_int32(fp8_weight)
|
||||
# Repack weights to marlin format
|
||||
marlin_qweight = ops.gptq_marlin_repack(
|
||||
b_q_weight=packed_gptq_qweight,
|
||||
perm=torch.empty(0, dtype=torch.int, device="cuda"),
|
||||
size_k=size_k,
|
||||
size_n=size_n,
|
||||
num_bits=8,
|
||||
)
|
||||
|
||||
# WEIGHT SCALES
|
||||
# Currently Marlin doesn't support per-tensor scales, so we
|
||||
# expand it to channelwise
|
||||
scales = weight_scale.repeat(1, size_n).to(a_input.dtype).to("cuda")
|
||||
# Permute scales
|
||||
marlin_scales = marlin_permute_scales(s=scales,
|
||||
size_k=size_k,
|
||||
size_n=size_n,
|
||||
group_size=-1)
|
||||
|
||||
workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_MIN_THREAD_N,
|
||||
GPTQ_MARLIN_MAX_PARALLEL)
|
||||
|
||||
opcheck(torch.ops._C.fp8_marlin_gemm,
|
||||
(a_input, marlin_qweight, marlin_scales, workspace.scratch,
|
||||
num_bits, a_input.shape[0], b_weight.shape[1], a_input.shape[1]))
|
||||
|
||||
output = ops.fp8_marlin_gemm(
|
||||
a=a_input,
|
||||
b_q_weight=marlin_qweight,
|
||||
b_scales=marlin_scales,
|
||||
workspace=workspace.scratch,
|
||||
num_bits=num_bits,
|
||||
size_m=a_input.shape[0],
|
||||
size_n=b_weight.shape[1],
|
||||
size_k=a_input.shape[1],
|
||||
)
|
||||
output_ref = torch.matmul(a_input, b_weight)
|
||||
|
||||
torch.cuda.synchronize()
|
||||
|
||||
max_diff = compute_max_diff(output, output_ref)
|
||||
|
||||
assert max_diff < 0.04
|
||||
|
||||
|
||||
@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"),
|
||||
reason="Marlin is not supported on this GPU type.")
|
||||
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
|
||||
@ -432,25 +371,23 @@ def test_awq_marlin_gemm(
|
||||
g_idx = torch.empty(0, dtype=torch.int, device=marlin_q_w.device)
|
||||
sort_indices = torch.empty(0, dtype=torch.int, device=marlin_q_w.device)
|
||||
is_k_full = True
|
||||
has_zp = True
|
||||
|
||||
workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_MIN_THREAD_N,
|
||||
GPTQ_MARLIN_MAX_PARALLEL)
|
||||
workspace = marlin_make_workspace_new(a_input.device)
|
||||
|
||||
output = ops.gptq_marlin_gemm(
|
||||
a_input,
|
||||
None,
|
||||
marlin_q_w,
|
||||
marlin_s,
|
||||
marlin_zp,
|
||||
g_idx,
|
||||
sort_indices,
|
||||
workspace.scratch,
|
||||
workspace,
|
||||
quant_type,
|
||||
a_input.shape[0],
|
||||
b_weight.shape[1],
|
||||
a_input.shape[1],
|
||||
is_k_full=is_k_full,
|
||||
has_zp=has_zp,
|
||||
use_fp32_reduce=use_fp32_reduce,
|
||||
is_zp_float=False,
|
||||
)
|
||||
@ -508,23 +445,22 @@ def test_hqq_marlin_gemm(
|
||||
g_idx = marlin_make_empty_g_idx(dev)
|
||||
g_idx_sort_indices = marlin_make_empty_g_idx(dev)
|
||||
|
||||
workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_MIN_THREAD_N,
|
||||
GPTQ_MARLIN_MAX_PARALLEL)
|
||||
workspace = marlin_make_workspace_new(b_weight.device)
|
||||
|
||||
output = ops.gptq_marlin_gemm(
|
||||
a_input,
|
||||
None,
|
||||
marlin_w_q,
|
||||
marlin_s,
|
||||
marlin_zp,
|
||||
g_idx,
|
||||
g_idx_sort_indices,
|
||||
workspace.scratch,
|
||||
workspace,
|
||||
quant_type,
|
||||
a_input.shape[0],
|
||||
b_weight.shape[0],
|
||||
a_input.shape[1],
|
||||
is_k_full=True,
|
||||
has_zp=True,
|
||||
use_fp32_reduce=use_fp32_reduce,
|
||||
is_zp_float=True,
|
||||
)
|
||||
@ -621,23 +557,22 @@ def test_marlin_gemm_subset_input():
|
||||
b_weight, quant_type, group_size, False)
|
||||
|
||||
marlin_zp = marlin_make_empty_g_idx(marlin_s.device)
|
||||
workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_MIN_THREAD_N,
|
||||
GPTQ_MARLIN_MAX_PARALLEL)
|
||||
workspace = marlin_make_workspace_new(a_input.device)
|
||||
|
||||
output = ops.gptq_marlin_gemm(
|
||||
a_input,
|
||||
None,
|
||||
marlin_q_w,
|
||||
marlin_s,
|
||||
marlin_zp,
|
||||
g_idx,
|
||||
sort_indices,
|
||||
workspace.scratch,
|
||||
workspace,
|
||||
quant_type,
|
||||
a_input.shape[0],
|
||||
b_weight.shape[1],
|
||||
a_input.shape[1],
|
||||
is_k_full=True,
|
||||
has_zp=False,
|
||||
use_atomic_add=False,
|
||||
use_fp32_reduce=True,
|
||||
is_zp_float=False,
|
||||
|
@ -325,18 +325,18 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):

    @register_fake("_C::gptq_marlin_gemm")
    def _gptq_marlin_gemm_fake(a: torch.Tensor,
                               c: Optional[torch.Tensor],
                               b_q_weight: torch.Tensor,
                               b_scales: torch.Tensor,
                               b_zeros: torch.Tensor,
                               g_idx: torch.Tensor,
                               perm: torch.Tensor,
                               b_zeros: Optional[torch.Tensor],
                               g_idx: Optional[torch.Tensor],
                               perm: Optional[torch.Tensor],
                               workspace: torch.Tensor,
                               b_q_type: ScalarType,
                               b_q_type_id: int,
                               size_m: torch.SymInt,
                               size_n: torch.SymInt,
                               size_k: torch.SymInt,
                               is_k_full: bool,
                               has_zp: bool = False,
                               is_k_full: bool = True,
                               use_atomic_add: bool = False,
                               use_fp32_reduce: bool = False,
                               is_zp_float: bool = False) -> torch.Tensor:
@ -407,14 +407,6 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
                                dtype=codebooks.dtype,
                                device=codebooks.device)

    @register_fake("_C::fp8_marlin_gemm")
    def _fp8_marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
                              b_scales: torch.Tensor, workspace: torch.Tensor,
                              num_bits: int, size_m: torch.SymInt,
                              size_n: torch.SymInt,
                              size_k: torch.SymInt) -> torch.Tensor:
        return torch.empty((size_m, size_n), dtype=a.dtype, device=a.device)

    @register_fake("_C::machete_mm")
    def machete_mm_fake(
            a: torch.Tensor,
@ -815,35 +807,26 @@ def awq_marlin_moe_repack(b_q_weight: torch.Tensor, perm: torch.Tensor,


def gptq_marlin_gemm(a: torch.Tensor,
                     c: Optional[torch.Tensor],
                     b_q_weight: torch.Tensor,
                     b_scales: torch.Tensor,
                     b_zeros: torch.Tensor,
                     g_idx: torch.Tensor,
                     perm: torch.Tensor,
                     b_zeros: Optional[torch.Tensor],
                     g_idx: Optional[torch.Tensor],
                     perm: Optional[torch.Tensor],
                     workspace: torch.Tensor,
                     b_q_type: ScalarType,
                     size_m: int,
                     size_n: int,
                     size_k: int,
                     is_k_full: bool,
                     has_zp: bool = False,
                     is_k_full: bool = True,
                     use_atomic_add: bool = False,
                     use_fp32_reduce: bool = False,
                     is_zp_float: bool = False) -> torch.Tensor:
    return torch.ops._C.gptq_marlin_gemm(a, b_q_weight, b_scales, b_zeros,
    return torch.ops._C.gptq_marlin_gemm(a, c, b_q_weight, b_scales, b_zeros,
                                         g_idx, perm, workspace, b_q_type.id,
                                         size_m, size_n, size_k, is_k_full,
                                         has_zp, use_atomic_add,
                                         use_fp32_reduce, is_zp_float)


# fp8 marlin
def fp8_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
                    b_scales: torch.Tensor, workspace: torch.Tensor,
                    num_bits: int, size_m: int, size_n: int,
                    size_k: int) -> torch.Tensor:
    return torch.ops._C.fp8_marlin_gemm(a, b_q_weight, b_scales, workspace,
                                        num_bits, size_m, size_n, size_k)
                                         use_atomic_add, use_fp32_reduce,
                                         is_zp_float)


# machete
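
With fp8_marlin_gemm removed here, FP8 weight-only GEMM is served by the same gptq_marlin_gemm entry point with b_q_type=scalar_types.float8_e4m3fn, and weights are prepared with the marlin_quant_fp8_torch helper, as the updated kernel tests do. A hedged sketch of that path (the transpose convention and None g_idx/perm follow the test; only channelwise or group_size=128 is supported for fp8):

import torch
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
    marlin_make_empty_g_idx, marlin_make_workspace_new)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
    marlin_quant_fp8_torch)
from vllm.scalar_type import scalar_types

size_m, size_k, size_n, group_size = 16, 256, 512, -1  # -1 = channelwise
a_input = torch.randn((size_m, size_k), dtype=torch.float16, device="cuda")
b_weight = torch.randn((size_k, size_n), dtype=torch.float16, device="cuda")

# As in the updated test_gptq_marlin_gemm: quantize the transposed weight.
w_ref, marlin_q_w, marlin_s = marlin_quant_fp8_torch(b_weight.T, group_size)
marlin_zp = marlin_make_empty_g_idx(marlin_s.device)
workspace = marlin_make_workspace_new(a_input.device)

output = ops.gptq_marlin_gemm(
    a_input, None, marlin_q_w, marlin_s, marlin_zp, None, None, workspace,
    scalar_types.float8_e4m3fn, size_m, size_n, size_k,
    is_k_full=True, use_fp32_reduce=True)
print(output.shape)  # (size_m, size_n)
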
@ -7,163 +7,13 @@ import torch

import vllm._custom_ops as ops
from vllm.model_executor.layers.fused_moe.fused_moe import (
    fused_topk, moe_align_block_size, try_get_optimal_moe_config)
from vllm.scalar_type import scalar_types
    moe_align_block_size, try_get_optimal_moe_config)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
    marlin_make_workspace_new, maybe_warn_marlin_atomic_add)
from vllm.scalar_type import ScalarType, scalar_types
from vllm.utils import direct_register_custom_op


def get_scalar_type(num_bits: int, has_zp: bool):
    if has_zp:
        return scalar_types.uint4 if num_bits == 4 else scalar_types.uint8
    else:
        return scalar_types.uint4b8 if num_bits == 4 else scalar_types.uint8b128


def single_marlin_moe(
        hidden_states: torch.Tensor,
        w: torch.Tensor,
        scales: torch.Tensor,
        gating_output: torch.Tensor,
        topk: int,
        renormalize: bool,
        global_num_experts: int = -1,
        expert_map: Optional[torch.Tensor] = None,
        g_idx: Optional[torch.Tensor] = None,
        sort_indices: Optional[torch.Tensor] = None,
        w_zeros: Optional[torch.Tensor] = None,
        workspace: Optional[torch.Tensor] = None,
        num_bits: int = 8,
        is_k_full: bool = True,
) -> torch.Tensor:
    """
    This function computes the multiplication of hidden_states with expert
    weights used in Marlin MoE, using weights w and top-k gating mechanism.
    Its purpose is testing and debugging the fused MoE kernel.

    Parameters:
    - hidden_states (torch.Tensor): The input tensor to the Marlin Mul.
    - w (torch.Tensor): The set of expert weights.
    - scales (torch.Tensor): The quantization scales.
    - gating_output (torch.Tensor): The output of the gating operation
        (before softmax).
    - g_idx (Optional[torch.Tensor]): Optional act_order indices.
    - sort_indices (Optional[torch.Tensor]): Optional act_order input
        permutation.
    - topk (int): The number of top-k experts to select.
    - renormalize (bool): If True, renormalize the top-k weights to sum to 1.
    - w_zeros (Optional[torch.Tensor]): Optional zero points to be used for w.
    - num_bits (int): The number of bits in expert weights quantization.

    Returns:
    - torch.Tensor: The output tensor after applying the MoE layer.
    """
    # Check constraints.
    assert hidden_states.shape[0] == gating_output.shape[0], (
        "Number of tokens mismatch")
    assert hidden_states.shape[1] == w.shape[1] * 16, "Hidden size mismatch"
    assert gating_output.shape[1] == w.shape[0], "Number of experts mismatch"
    assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
    assert w.is_contiguous(), "Expert weights must be contiguous"
    assert hidden_states.dtype in [torch.float16, torch.bfloat16]
    assert num_bits in [4, 8]

    M, K = hidden_states.shape
    E = w.shape[0]
    N = w.shape[2] // (num_bits // 2)

    topk_weights, topk_ids, token_expert_indices = fused_topk(
        hidden_states, gating_output, topk, renormalize)

    # This might not be an optimal config for a single MMM
    get_config_func = functools.partial(try_get_optimal_moe_config,
                                        w.shape,
                                        w.shape,
                                        topk_ids.shape[1],
                                        None,
                                        is_marlin=True)
    config = get_config_func(M)

    block_size_m = config['BLOCK_SIZE_M']

    if global_num_experts == -1:
        global_num_experts = E
    sorted_token_ids, expert_ids, num_tokens_post_padded = \
        moe_align_block_size(topk_ids, block_size_m, E, expert_map)

    if workspace is None:
        max_workspace_size = (max(2 * N, K) // 64) * \
            (sorted_token_ids.size(0) // block_size_m)
        device = hidden_states.device
        sms = torch.cuda.get_device_properties(device).multi_processor_count
        max_workspace_size = min(max_workspace_size, sms)
        workspace = torch.zeros(max_workspace_size,
                                dtype=torch.int,
                                device=device,
                                requires_grad=False)

    scalar_type = get_scalar_type(num_bits, w_zeros is not None)
    intermediate_cache = torch.empty(
        (M * topk_ids.shape[1], N),
        device=hidden_states.device,
        dtype=hidden_states.dtype,
    )

    ops.moe_wna16_marlin_gemm(hidden_states,
                              intermediate_cache,
                              w,
                              scales,
                              w_zeros,
                              g_idx,
                              sort_indices,
                              workspace,
                              sorted_token_ids,
                              expert_ids,
                              num_tokens_post_padded,
                              topk_weights,
                              moe_block_size=block_size_m,
                              top_k=topk,
                              mul_topk_weights=False,
                              is_ep=expert_map is not None,
                              b_q_type=scalar_type,
                              size_m=M,
                              size_n=N,
                              size_k=K,
                              is_k_full=is_k_full,
                              use_atomic_add=False,
                              use_fp32_reduce=True,
                              is_zp_float=False)
    intermediate_cache = intermediate_cache.view(-1, topk, N)

    return torch.sum(intermediate_cache.view(*intermediate_cache.shape), dim=1)


def single_marlin_moe_fake(
        hidden_states: torch.Tensor,
        w: torch.Tensor,
        scales: torch.Tensor,
        gating_output: torch.Tensor,
        topk: int,
        renormalize: bool,
        global_num_experts: int = -1,
        expert_map: Optional[torch.Tensor] = None,
        g_idx: Optional[torch.Tensor] = None,
        sort_indices: Optional[torch.Tensor] = None,
        w_zeros: Optional[torch.Tensor] = None,
        workspace: Optional[torch.Tensor] = None,
        num_bits: int = 8,
        is_k_full: bool = True,
) -> torch.Tensor:
    return torch.empty_like(hidden_states)


direct_register_custom_op(
    op_name="single_marlin_moe",
    op_func=single_marlin_moe,
    mutates_args=[],
    fake_impl=single_marlin_moe_fake,
)


def fused_marlin_moe(hidden_states: torch.Tensor,
                     w1: torch.Tensor,
                     w2: torch.Tensor,
@ -172,6 +22,7 @@ def fused_marlin_moe(hidden_states: torch.Tensor,
                     gating_output: torch.Tensor,
                     topk_weights: torch.Tensor,
                     topk_ids: torch.Tensor,
                     quant_type_id: int,
                     global_num_experts: int = -1,
                     expert_map: Optional[torch.Tensor] = None,
                     g_idx1: Optional[torch.Tensor] = None,
@ -181,7 +32,6 @@ def fused_marlin_moe(hidden_states: torch.Tensor,
                     w1_zeros: Optional[torch.Tensor] = None,
                     w2_zeros: Optional[torch.Tensor] = None,
                     workspace: Optional[torch.Tensor] = None,
                     num_bits: int = 8,
                     is_k_full: bool = True,
                     inplace: bool = False) -> torch.Tensor:
    """
@ -211,6 +61,15 @@ def fused_marlin_moe(hidden_states: torch.Tensor,
    Returns:
    - torch.Tensor: The output tensor after applying the MoE layer.
    """
    quant_type = ScalarType.from_id(quant_type_id)
    assert quant_type in [
        scalar_types.uint4, scalar_types.uint8b128, scalar_types.uint4b8,
        scalar_types.float8_e4m3fn
    ]

    int4_scalar_types = [scalar_types.uint4, scalar_types.uint4b8]
    num_bits = 4 if quant_type in int4_scalar_types else 8

    # Check constraints.
    assert hidden_states.shape[0] == gating_output.shape[
        0], "Number of tokens mismatch"
@ -248,18 +107,7 @@ def fused_marlin_moe(hidden_states: torch.Tensor,
                                                     expert_map)

    if workspace is None:
        max_workspace_size = (max(2 * N, K) // 64) * \
            (sorted_token_ids.size(0) // block_size_m)
        device = hidden_states.device
        sms = torch.cuda.get_device_properties(device).multi_processor_count
        max_workspace_size = min(max_workspace_size, sms * 4)
        workspace = torch.zeros(max_workspace_size,
                                dtype=torch.int,
                                device=device,
                                requires_grad=False)

    scalar_type1 = get_scalar_type(num_bits, w1_zeros is not None)
    scalar_type2 = get_scalar_type(num_bits, w2_zeros is not None)
    workspace = marlin_make_workspace_new(hidden_states.device, 4)

    intermediate_cache2 = torch.empty(
        (M * topk_ids.shape[1], N),
@ -276,6 +124,7 @@ def fused_marlin_moe(hidden_states: torch.Tensor,
    intermediate_cache3 = intermediate_cache13[:M * topk_ids.shape[1] * K]
    intermediate_cache3 = intermediate_cache3.view(-1, K)

    maybe_warn_marlin_atomic_add(hidden_states.device, hidden_states.dtype)
    use_atomic_add = hidden_states.dtype == torch.half or \
        torch.cuda.get_device_capability(hidden_states.device)[0] >= 9

@ -296,7 +145,7 @@ def fused_marlin_moe(hidden_states: torch.Tensor,
        top_k=topk,
        mul_topk_weights=False,
        is_ep=expert_map is not None,
        b_q_type=scalar_type1,
        b_q_type=quant_type,
        size_m=M,
        size_n=2 * N,
        size_k=K,
@ -328,7 +177,7 @@ def fused_marlin_moe(hidden_states: torch.Tensor,
        top_k=1,
        mul_topk_weights=True,
        is_ep=expert_map is not None,
        b_q_type=scalar_type2,
        b_q_type=quant_type,
        size_m=M * topk,
        size_n=K,
        size_k=N,
@ -351,6 +200,7 @@ def fused_marlin_moe_fake(hidden_states: torch.Tensor,
                          gating_output: torch.Tensor,
                          topk_weights: torch.Tensor,
                          topk_ids: torch.Tensor,
                          quant_type_id: int,
                          global_num_experts: int = -1,
                          expert_map: Optional[torch.Tensor] = None,
                          g_idx1: Optional[torch.Tensor] = None,
@ -360,7 +210,6 @@ def fused_marlin_moe_fake(hidden_states: torch.Tensor,
                          w1_zeros: Optional[torch.Tensor] = None,
                          w2_zeros: Optional[torch.Tensor] = None,
                          workspace: Optional[torch.Tensor] = None,
                          num_bits: int = 8,
                          is_k_full: bool = True,
                          inplace: bool = False) -> torch.Tensor:
    return torch.empty_like(hidden_states)
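For readers tracing the MoE math these fused kernels implement, the following is a minimal, unquantized PyTorch reference of the route-to-top-k-experts / per-expert GEMM / reduce pattern. It is an illustrative sketch only: the function and variable names are not part of vLLM, and it ignores quantization, expert parallelism, and the block-aligned token sorting the Marlin kernels rely on. Note that the weighted sum below mirrors the full fused path; the single_marlin_moe debug helper above simply sums expert outputs without applying the top-k weights.

import torch


def moe_reference(hidden_states: torch.Tensor, w: torch.Tensor,
                  gating_output: torch.Tensor, topk: int,
                  renormalize: bool) -> torch.Tensor:
    # hidden_states: (M, K); w: (E, K, N) dense (dequantized) expert weights;
    # gating_output: (M, E) router logits.
    scores = torch.softmax(gating_output, dim=-1)
    topk_weights, topk_ids = torch.topk(scores, topk, dim=-1)
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    out = torch.zeros(hidden_states.shape[0], w.shape[-1],
                      dtype=hidden_states.dtype, device=hidden_states.device)
    for e in range(w.shape[0]):
        # Tokens (and their top-k slot) routed to expert e.
        token_idx, slot = (topk_ids == e).nonzero(as_tuple=True)
        if token_idx.numel() == 0:
            continue
        expert_out = hidden_states[token_idx] @ w[e]
        out[token_idx] += topk_weights[token_idx, slot, None].to(
            expert_out.dtype) * expert_out
    return out

For small shapes, a reference like this can be compared against the fused path to sanity-check routing and reduction behaviour.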
@ -22,9 +22,10 @@ from vllm.model_executor.layers.quantization.utils import replace_parameter
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
    apply_awq_marlin_linear, awq_to_marlin_zero_points, check_marlin_supported,
    check_marlin_supports_layer, check_moe_marlin_supports_layer,
    marlin_make_empty_g_idx, marlin_make_workspace, marlin_moe_permute_scales,
    marlin_permute_scales, moe_awq_to_marlin_zero_points,
    verify_marlin_supported, verify_marlin_supports_shape)
    marlin_make_empty_g_idx, marlin_make_workspace_new,
    marlin_moe_permute_scales, marlin_permute_scales,
    moe_awq_to_marlin_zero_points, verify_marlin_supported,
    verify_marlin_supports_shape)
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.parameter import (GroupQuantScaleParameter,
                                           PackedvLLMParameter)
@ -267,8 +268,7 @@ class AWQMarlinLinearMethod(LinearMethodBase):
                                       requires_grad=False)

        # Allocate marlin workspace
        layer.workspace = marlin_make_workspace(
            layer.output_size_per_partition, device)
        layer.workspace = marlin_make_workspace_new(device)

        # Repack weights from AWQ format to marlin format.
        marlin_qweight = ops.awq_marlin_repack(
@ -322,6 +322,9 @@ class AWQMoEMethod(FusedMoEMethodBase):

    def __init__(self, quant_config: AWQMarlinConfig):
        self.quant_config = quant_config
        if self.quant_config.weight_bits != 4:
            raise ValueError("AWQMoEMethod only supports 4bit now.")
        self.quant_type = scalar_types.uint4

    def create_weights(self, layer: torch.nn.Module, num_experts: int,
                       hidden_size: int, intermediate_size_per_partition: int,
@ -396,11 +399,7 @@ class AWQMoEMethod(FusedMoEMethodBase):
        set_weight_attrs(w2_qzeros, extra_weight_attrs)

        device = layer.w13_qweight.device
        sms = torch.cuda.get_device_properties(device).multi_processor_count
        layer.workspace = torch.zeros((sms * 4, ),
                                      dtype=torch.int,
                                      device=device,
                                      requires_grad=False)
        layer.workspace = marlin_make_workspace_new(device, 4)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        num_experts = layer.w13_qweight.shape[0]
@ -511,10 +510,9 @@ class AWQMoEMethod(FusedMoEMethodBase):
            router_logits,
            topk_weights,
            topk_ids,
            quant_type_id=self.quant_type.id,
            global_num_experts=global_num_experts,
            expert_map=expert_map,
            w1_zeros=layer.w13_qzeros,
            w2_zeros=layer.w2_qzeros,
            workspace=layer.workspace,
            num_bits=self.quant_config.weight_bits,
        )
            workspace=layer.workspace)
@ -55,7 +55,7 @@ class CompressedTensorsW8A16Fp8(CompressedTensorsScheme):
        # required by torch.compile to be torch.nn.Parameter
        layer.input_scale = torch.nn.Parameter(layer.input_scale.data,
                                               requires_grad=False)
        prepare_fp8_layer_for_marlin(layer, strategy="channel")
        prepare_fp8_layer_for_marlin(layer)

    def create_weights(self, layer: torch.nn.Module, input_size: int,
                       output_partition_sizes: List[int],
@ -68,6 +68,7 @@ class CompressedTensorsW8A16Fp8(CompressedTensorsScheme):
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition
        layer.orig_dtype = params_dtype
        layer.weight_block_size = None

        # WEIGHT
        weight = ModelWeightParameter(data=torch.empty(
@ -21,19 +21,21 @@ from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
    apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin)
    apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin,
    prepare_moe_fp8_layer_for_marlin)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    is_layer_skipped)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
    Fp8LinearOp, all_close_1d, convert_to_channelwise,
    cutlass_block_fp8_supported, cutlass_fp8_supported,
    maybe_create_device_identity, normalize_e4m3fn_to_e4m3fnuz,
    per_tensor_dequantize, requantize_with_max_scale)
    Fp8LinearOp, all_close_1d, cutlass_block_fp8_supported,
    cutlass_fp8_supported, maybe_create_device_identity,
    normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize,
    requantize_with_max_scale)
from vllm.model_executor.parameter import (BlockQuantScaleParameter,
                                           ModelWeightParameter,
                                           PerTensorScaleParameter)
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types

ACTIVATION_SCHEMES = ["static", "dynamic"]

@ -181,10 +183,6 @@ class Fp8LinearMethod(LinearMethodBase):
            self.use_marlin = False

        self.block_quant = self.quant_config.weight_block_size is not None
        if self.block_quant:
            # Marlin doesn't support block-wise fp8
            self.use_marlin = False

        self.fp8_linear = Fp8LinearOp(
            # Default to using per_token quantization if cutlass is supported
            use_per_token_if_dynamic=cutlass_fp8_supported())
@ -203,10 +201,16 @@ class Fp8LinearMethod(LinearMethodBase):

        output_size_per_partition = sum(output_partition_sizes)
        weight_loader = extra_weight_attrs.get("weight_loader")
        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition
        layer.orig_dtype = params_dtype
        layer.weight_block_size = None

        if self.block_quant:
            tp_size = get_tensor_model_parallel_world_size()
            assert self.quant_config.weight_block_size is not None
            layer.weight_block_size = self.quant_config.weight_block_size
            block_n, block_k = (
                self.quant_config.weight_block_size[0],
                self.quant_config.weight_block_size[1],
@ -229,12 +233,6 @@ class Fp8LinearMethod(LinearMethodBase):
                        f"{output_partition_size} is not divisible by "
                        f"weight quantization block_n = {block_n}.")

        layer.logical_widths = output_partition_sizes

        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition
        layer.orig_dtype = params_dtype

        # WEIGHT
        weight_dtype = (torch.float8_e4m3fn
                        if self.quant_config.is_checkpoint_fp8_serialized else
@ -303,9 +301,11 @@ class Fp8LinearMethod(LinearMethodBase):
        return weight

    def process_weights_after_loading(self, layer: Module) -> None:
        size_k_first = True
        # TODO(rob): refactor block quant into separate class.
        if self.block_quant:
            assert self.quant_config.activation_scheme == "dynamic"
            size_k_first = False
            if current_platform.is_fp8_fnuz():
                weight, weight_scale_inv, _ = \
                    normalize_e4m3fn_to_e4m3fnuz(
@ -321,21 +321,12 @@ class Fp8LinearMethod(LinearMethodBase):
            layer.weight = Parameter(weight, requires_grad=False)
            layer.weight_scale_inv = Parameter(weight_scale_inv,
                                               requires_grad=False)
            return

        # If checkpoint not serialized fp8, quantize the weights.
        if not self.quant_config.is_checkpoint_fp8_serialized:
        elif not self.quant_config.is_checkpoint_fp8_serialized:
            qweight, weight_scale = ops.scaled_fp8_quant(layer.weight,
                                                         scale=None)

            # If using marlin (w8a16), kernel uses channelwise weights,
            # so extend the weight scales to be channelwise.
            if self.use_marlin:
                assert weight_scale.numel() == 1
                weight_scale = convert_to_channelwise(
                    weight_scale.expand(len(layer.logical_widths)),
                    layer.logical_widths)

            # Update the layer with the new values.
            layer.weight = Parameter(qweight.t(), requires_grad=False)
            layer.weight_scale = Parameter(weight_scale, requires_grad=False)
@ -349,20 +340,14 @@ class Fp8LinearMethod(LinearMethodBase):
            if self.quant_config.activation_scheme == "static":
                layer.input_scale = torch.nn.Parameter(layer.input_scale.data,
                                                       requires_grad=False)
            # If using marlin (w8a16), kernel uses channelwise weights,
            # so extend the weight scales to be channelwise.
            if self.use_marlin:
                weight = layer.weight
                weight_scale = convert_to_channelwise(layer.weight_scale,
                                                      layer.logical_widths)

            weight = layer.weight
            weight_scale = layer.weight_scale

            # If using w8a8, torch._scaled_mm needs per tensor, so
            # requantize the logical shards as a single weight.
            else:
            if not self.use_marlin:
                # Dequant -> Quant with max scale so we can run per tensor.
                weight = layer.weight
                weight_scale = layer.weight_scale

                if current_platform.is_fp8_fnuz():
                    weight, weight_scale, input_scale = \
                        normalize_e4m3fn_to_e4m3fnuz(
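The comments above note that the Marlin w8a16 path wants channel-wise weight scales even when the checkpoint stores only a per-tensor (or per-shard) FP8 scale. A rough sketch of that expansion with made-up widths; vLLM's convert_to_channelwise helper performs the real conversion, including the per-shard bookkeeping:

import torch

# Hypothetical partition widths for a fused projection (e.g. q/k/v shards).
logical_widths = [4096, 4096, 1024]
# One scale per logical partition (identical here, as in the per-tensor case).
per_partition_scales = torch.tensor([0.02, 0.02, 0.02])

channelwise = torch.cat([
    torch.full((width,), float(scale))   # repeat the shard scale per channel
    for scale, width in zip(per_partition_scales, logical_widths)
])
assert channelwise.numel() == sum(logical_widths)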
@ -388,7 +373,7 @@ class Fp8LinearMethod(LinearMethodBase):
                                                   requires_grad=False)

        if self.use_marlin:
            prepare_fp8_layer_for_marlin(layer)
            prepare_fp8_layer_for_marlin(layer, size_k_first)
            # Activations not quantized for marlin.
            del layer.input_scale

@ -444,6 +429,14 @@ class Fp8MoEMethod(FusedMoEMethodBase):
        self.quant_config = quant_config
        self.block_quant = self.quant_config.weight_block_size is not None

        # For GPUs that lack FP8 hardware support, we can leverage the Marlin
        # kernel for fast weight-only FP8 quantization
        self.use_marlin = (not current_platform.has_device_capability(89)
                           or envs.VLLM_TEST_FORCE_FP8_MARLIN)
        # Disable marlin for rocm
        if current_platform.is_rocm():
            self.use_marlin = False

        # Check for DeepGemm support.
        self.allow_deep_gemm = False
        if envs.VLLM_USE_DEEP_GEMM:
@ -461,10 +454,17 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                       intermediate_size_per_partition: int,
                       params_dtype: torch.dtype, **extra_weight_attrs):

        layer.intermediate_size_per_partition = intermediate_size_per_partition
        layer.hidden_size = hidden_size
        layer.num_experts = num_experts
        layer.orig_dtype = params_dtype
        layer.weight_block_size = None

        if self.quant_config.is_checkpoint_fp8_serialized:
            params_dtype = torch.float8_e4m3fn
        if self.block_quant:
            assert self.quant_config.weight_block_size is not None
            layer.weight_block_size = self.quant_config.weight_block_size
            tp_size = get_tensor_model_parallel_world_size()
            block_n, block_k = (
                self.quant_config.weight_block_size[0],
@ -630,10 +630,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                layer.w2_weight_scale_inv = \
                    dg.get_col_major_tma_aligned_tensor(layer.w2_weight_scale_inv).contiguous()

            return

        # If checkpoint is fp16, quantize in place.
        if not self.quant_config.is_checkpoint_fp8_serialized:
        elif not self.quant_config.is_checkpoint_fp8_serialized:
            fp8_dtype = current_platform.fp8_dtype()
            w13_weight = torch.empty_like(layer.w13_weight.data,
                                          dtype=fp8_dtype)
@ -677,8 +675,6 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                                                 requires_grad=False)
            layer.w2_weight = torch.nn.Parameter(shuffled_w2,
                                                 requires_grad=False)
            return

        # If checkpoint is fp8, we need to handle that the
        # MoE kernels require single activation scale and single weight
        # scale for w13 per expert.
@ -766,7 +762,12 @@ class Fp8MoEMethod(FusedMoEMethodBase):

            layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales,
                                                        requires_grad=False)
            return

        if self.use_marlin:
            prepare_moe_fp8_layer_for_marlin(layer, False)
            # Activations not quantized for marlin.
            del layer.w13_input_scale
            del layer.w2_input_scale

    def apply(
        self,
@ -801,6 +802,20 @@ class Fp8MoEMethod(FusedMoEMethodBase):
            e_score_correction_bias=e_score_correction_bias,
        )

        if self.use_marlin:
            return torch.ops.vllm.fused_marlin_moe(
                x,
                layer.w13_weight,
                layer.w2_weight,
                layer.w13_weight_scale,
                layer.w2_weight_scale,
                router_logits,
                topk_weights,
                topk_ids,
                quant_type_id=scalar_types.float8_e4m3fn.id,
                global_num_experts=global_num_experts,
                expert_map=expert_map)

        return fused_experts(
            x,
            layer.w13_weight,
@ -21,8 +21,8 @@ from vllm.model_executor.layers.quantization.utils.gptq_utils import (
    get_linear_quant_method)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
    check_marlin_supported, check_moe_marlin_supports_layer,
    marlin_moe_permute_scales, marlin_repeat_scales_on_all_ranks,
    verify_marlin_supported)
    marlin_make_workspace_new, marlin_moe_permute_scales,
    marlin_repeat_scales_on_all_ranks, verify_marlin_supported)
from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
                                           GroupQuantScaleParameter,
                                           PackedColumnParameter,
@ -350,6 +350,13 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):

    def __init__(self, quant_config: GPTQMarlinConfig) -> None:
        self.quant_config = quant_config
        if self.quant_config.quant_type.size_bits == 4:
            self.quant_type = scalar_types.uint4b8
        elif self.quant_config.quant_type.size_bits == 8:
            self.quant_type = scalar_types.uint8b128
        else:
            raise ValueError(
                "GPTQMarlinMoEMethod only supports int4 and int8 now.")

    def create_weights(
        self,
@ -498,11 +505,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
        set_weight_attrs(w2_g_idx_sort_indices, extra_weight_attrs)

        device = layer.w13_qweight.device
        sms = torch.cuda.get_device_properties(device).multi_processor_count
        layer.workspace = torch.zeros((sms * 4, ),
                                      dtype=torch.int,
                                      device=device,
                                      requires_grad=False)
        layer.workspace = marlin_make_workspace_new(device, 4)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:

@ -633,12 +636,12 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
            router_logits,
            topk_weights,
            topk_ids,
            quant_type_id=self.quant_type.id,
            global_num_experts=global_num_experts,
            expert_map=expert_map,
            g_idx1=layer.w13_g_idx,
            g_idx2=layer.w2_g_idx,
            sort_indices1=layer.w13_g_idx_sort_indices,
            sort_indices2=layer.w2_g_idx_sort_indices,
            num_bits=self.quant_config.quant_type.size_bits,
            workspace=layer.workspace,
            is_k_full=self.is_k_full)
@ -8,7 +7,7 @@ from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
    MARLIN_SUPPORTED_GROUP_SIZES, apply_gptq_marlin_linear,
    check_marlin_supports_shape, marlin_is_k_full, marlin_make_empty_g_idx,
    marlin_make_workspace, marlin_permute_scales, marlin_sort_g_idx,
    marlin_make_workspace_new, marlin_permute_scales, marlin_sort_g_idx,
    marlin_zero_points, query_marlin_supported_quant_types, unpack_cols)
from vllm.model_executor.parameter import (BasevLLMParameter,
                                           permute_param_layout_)
@ -53,8 +53,7 @@ class MarlinLinearKernel(MPLinearKernel):
        self.is_k_full = marlin_is_k_full(c.has_g_idx, row_parallel)

        # Allocate marlin workspace.
        self.workspace = marlin_make_workspace(c.partition_weight_shape[1],
                                               device)
        self.workspace = marlin_make_workspace_new(device)

        # Default names since marlin requires empty parameters for these,
        # TODO: remove this requirement from marlin (allow optional tensors)
@ -127,6 +126,5 @@ class MarlinLinearKernel(MPLinearKernel):
            wtype=c.weight_type,
            input_size_per_partition=c.partition_weight_shape[0],
            output_size_per_partition=c.partition_weight_shape[1],
            has_zp=self.config.zero_points,
            is_k_full=self.is_k_full,
            bias=bias)
@ -7,12 +7,15 @@ import torch

import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import LinearBase
from vllm.platforms import current_platform
from vllm.scalar_type import ScalarType, scalar_types

from .quant_utils import pack_cols, unpack_cols

logger = init_logger(__name__)

GPTQ_MARLIN_TILE = 16
GPTQ_MARLIN_MIN_THREAD_N = 64
GPTQ_MARLIN_MIN_THREAD_K = 128
@ -29,9 +32,11 @@ USE_FP32_REDUCE_DEFAULT = True
# For binary size and compile time, we don't support the same types for with and
# without runtime zero-point. We support common cases, i.e. AWQ and GPTQ.
# TODO: we may want to move this into the C++ so its closer to the actual impl
def query_marlin_supported_quant_types(has_zp: bool,
                                       device_capability: Optional[int] = None
                                       ):
def query_marlin_supported_quant_types(
    has_zp: bool,
    include_fp_type: bool = True,
    device_capability: Optional[int] = None,
):
    if device_capability is None:
        capability_tuple = current_platform.get_device_capability()
        device_capability = (-1 if capability_tuple is None else
@ -42,12 +47,13 @@ def query_marlin_supported_quant_types(has_zp: bool,

    if has_zp:
        # AWQ style, unsigned + runtime zero-point
        return [scalar_types.uint4, scalar_types.uint8]
        return [scalar_types.uint4]
    else:
        # GPTQ style, unsigned + symmetric bias
        # TODO: once fp8_marlin is merged into "gptq_marlin" we should be able
        # to add `scalar_types.float8_e4m3fn` here
        return [scalar_types.uint4b8, scalar_types.uint8b128]
        res = [scalar_types.uint4b8, scalar_types.uint8b128]
        if include_fp_type:
            res += [scalar_types.float8_e4m3fn]
        return res


def _check_marlin_supported(
@ -62,7 +68,7 @@ def _check_marlin_supported(
                             capability_tuple.to_int())

    supported_types = query_marlin_supported_quant_types(
        has_zp, device_capability)
        has_zp, True, device_capability)

    if quant_type not in supported_types:
        return (False, f"Marlin does not support weight_bits = {quant_type}. "
@ -175,6 +181,17 @@ def marlin_make_workspace(output_size_per_partition: int,
                       requires_grad=False)


def marlin_make_workspace_new(device: torch.device,
                              max_blocks_per_sm: int = 1) -> torch.Tensor:
    # In the new marlin kernel, we use the num of threadblocks as workspace
    # size. The num of threadblocks is sms_count * max_blocks_per_sm.
    sms = torch.cuda.get_device_properties(device).multi_processor_count
    return torch.zeros(sms * max_blocks_per_sm,
                       dtype=torch.int,
                       device=device,
                       requires_grad=False)


def marlin_is_k_full(act_order: bool, is_row_parallel: bool) -> bool:
    return (not act_order) or (act_order and not is_row_parallel)

@ -304,21 +321,50 @@ def moe_awq_to_marlin_zero_points(q_zp_packed: torch.Tensor, size_k: int,
    return output


def maybe_warn_marlin_atomic_add(device, dtype):
    if torch.compiler.is_dynamo_compiling():
        return
    device_capability = torch.cuda.get_device_capability(device)
    if device_capability[0] < 9 and dtype == torch.bfloat16:
        logger.info_once(
            "You are running Marlin kernel with bf16 on GPUs before SM90. "
            "You can consider change to fp16 to achieve better performance "
            "if possible.")


def maybe_warn_marlin_atomic_add_env():
    if torch.compiler.is_dynamo_compiling():
        return
    if envs.VLLM_MARLIN_USE_ATOMIC_ADD:
        return
    logger.info_once(
        "Marlin kernel can achieve better performance for small size_n "
        "with experimental use_atomic_add feature. "
        "You can consider set environment variable "
        "VLLM_MARLIN_USE_ATOMIC_ADD to 1 if possible.")


def should_use_atomic_add_reduce(m: int, n: int, k: int, device: torch.device,
                                 dtype: torch.dtype) -> bool:

    # the performance of atomicAdd is better than global reduce
    # only when m*n is small and k is large
    if n >= 2048 or k < 2048 or device.type != "cuda":
        return False

    # disable atomicAdd reduce by default,
    # one can enable it with VLLM_MARLIN_USE_ATOMIC_ADD=1
    if not envs.VLLM_MARLIN_USE_ATOMIC_ADD or device.type != "cuda":
    if not envs.VLLM_MARLIN_USE_ATOMIC_ADD:
        maybe_warn_marlin_atomic_add_env()
        return False

    # sm8x doesn't support atomicAdd + bfloat16 natively
    device_capability = torch.cuda.get_device_capability(device)
    if device_capability[0] < 9 and dtype == torch.bfloat16:
        maybe_warn_marlin_atomic_add(device, dtype)
        return False

    # the performance of atomicAdd is better than global reduce
    # only when m*n is small and k is large
    return n < 2048 and k >= 2048
    return True


def apply_gptq_marlin_linear(
@ -332,7 +378,6 @@ def apply_gptq_marlin_linear(
        wtype: ScalarType,
        output_size_per_partition: int,
        input_size_per_partition: int,
        has_zp: bool,
        is_k_full: bool,
        bias: Optional[torch.Tensor] = None,
        use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT) -> torch.Tensor:
@ -346,6 +391,7 @@ def apply_gptq_marlin_linear(
                                                  dtype=input.dtype)

    output = ops.gptq_marlin_gemm(reshaped_x,
                                  None,
                                  weight,
                                  weight_scale,
                                  weight_zp,
@ -358,7 +404,6 @@ def apply_gptq_marlin_linear(
                                  size_k=input_size_per_partition,
                                  is_k_full=is_k_full,
                                  use_atomic_add=use_atomic_add,
                                  has_zp=has_zp,
                                  use_fp32_reduce=use_fp32_reduce,
                                  is_zp_float=False)

@ -391,6 +436,7 @@ def apply_awq_marlin_linear(
                                                  dtype=input.dtype)

    output = ops.gptq_marlin_gemm(reshaped_x,
                                  None,
                                  weight,
                                  weight_scale,
                                  weight_zp,
@ -401,8 +447,6 @@ def apply_awq_marlin_linear(
                                  size_m=reshaped_x.shape[0],
                                  size_n=output_size_per_partition,
                                  size_k=input_size_per_partition,
                                  is_k_full=True,
                                  has_zp=True,
                                  use_atomic_add=use_atomic_add,
                                  use_fp32_reduce=use_fp32_reduce,
                                  is_zp_float=False)
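The two helpers introduced above encode the new workspace sizing (one int slot per resident threadblock) and the opt-in atomicAdd-vs-global-reduce heuristic. A self-contained sketch of both rules with the vLLM plumbing (logging, env lookup, device-type checks) stripped out; the function names below are illustrative, the real logic is marlin_make_workspace_new and should_use_atomic_add_reduce:

import torch


def workspace_size(device: torch.device, max_blocks_per_sm: int = 1) -> int:
    # Workspace slots = SM count * blocks each SM may run concurrently.
    sms = torch.cuda.get_device_properties(device).multi_processor_count
    return sms * max_blocks_per_sm


def prefer_atomic_add(m: int, n: int, k: int, dtype: torch.dtype,
                      capability_major: int, env_opt_in: bool) -> bool:
    # atomicAdd beats the fp32 global reduction only for small n and large k;
    # it is opt-in (VLLM_MARLIN_USE_ATOMIC_ADD) and, for bfloat16, needs the
    # native bf16 atomics available from SM90 onward.
    if n >= 2048 or k < 2048:
        return False
    if not env_opt_in:
        return False
    if capability_major < 9 and dtype == torch.bfloat16:
        return False
    return True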
@ -6,9 +6,11 @@ import torch

import vllm._custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
    USE_FP32_REDUCE_DEFAULT, marlin_make_workspace_new, marlin_permute_scales,
    should_use_atomic_add_reduce)
from vllm.platforms import current_platform

from .marlin_utils import marlin_make_workspace, marlin_permute_scales
from vllm.scalar_type import scalar_types

logger = init_logger(__name__)

@ -18,30 +20,40 @@ def is_fp8_marlin_supported():


def apply_fp8_marlin_linear(
    input: torch.Tensor,
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
    workspace: torch.Tensor,
    size_n: int,
    size_k: int,
    bias: Optional[torch.Tensor],
) -> torch.Tensor:
        input: torch.Tensor,
        weight: torch.Tensor,
        weight_scale: torch.Tensor,
        workspace: torch.Tensor,
        size_n: int,
        size_k: int,
        bias: Optional[torch.Tensor],
        use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT) -> torch.Tensor:
    # For GPUs that lack FP8 hardware support, we can leverage the
    # Marlin kernel for fast weight-only FP8 quantization

    reshaped_x = input.reshape(-1, input.shape[-1])
    out_shape = input.shape[:-1] + (size_n, )

    output = ops.fp8_marlin_gemm(
        a=reshaped_x,
        b_q_weight=weight,
        b_scales=weight_scale,
        workspace=workspace,
        num_bits=8,
        size_m=reshaped_x.shape[0],
        size_n=size_n,
        size_k=size_k,
    )
    use_atomic_add = should_use_atomic_add_reduce(m=reshaped_x.size(0),
                                                  n=size_n,
                                                  k=size_k,
                                                  device=input.device,
                                                  dtype=input.dtype)

    output = ops.gptq_marlin_gemm(a=reshaped_x,
                                  c=None,
                                  b_q_weight=weight,
                                  b_scales=weight_scale,
                                  b_zeros=None,
                                  g_idx=None,
                                  perm=None,
                                  workspace=workspace,
                                  b_q_type=scalar_types.float8_e4m3fn,
                                  size_m=reshaped_x.size(0),
                                  size_n=size_n,
                                  size_k=size_k,
                                  use_atomic_add=use_atomic_add,
                                  use_fp32_reduce=use_fp32_reduce)

    if bias is not None:
        output.add_(bias)  # In-place add
@ -50,7 +62,7 @@ def apply_fp8_marlin_linear(


def prepare_fp8_layer_for_marlin(layer: torch.nn.Module,
                                 strategy: str = "tensor") -> None:
                                 size_k_first: bool = True) -> None:
    logger.warning_once(
        "Your GPU does not have native support for FP8 computation but "
        "FP8 quantization is being used. Weight-only FP8 compression will "
@ -60,51 +72,234 @@ def prepare_fp8_layer_for_marlin(layer: torch.nn.Module,
    part_size_n = layer.output_size_per_partition
    part_size_k = layer.input_size_per_partition

    if size_k_first:
        assert layer.weight.shape == (part_size_k, part_size_n)
    else:
        assert layer.weight.shape == (part_size_n, part_size_k)

    device = layer.weight.device

    # WORKSPACE
    layer.workspace = marlin_make_workspace(part_size_n, device)
    layer.workspace = marlin_make_workspace_new(device)

    # WEIGHT
    # Repack weights to marlin format
    marlin_qweight = ops.gptq_marlin_repack(b_q_weight=pack_fp8_to_int32(
        layer.weight),
                                            perm=torch.empty(0,
                                                             dtype=torch.int,
                                                             device=device),
    perm = torch.empty(0, dtype=torch.int, device=device)
    qweight = pack_fp8_to_int32(layer.weight, size_k_first)
    if not size_k_first:
        qweight = qweight.T.contiguous()

    marlin_qweight = ops.gptq_marlin_repack(b_q_weight=qweight,
                                            perm=perm,
                                            size_k=part_size_k,
                                            size_n=part_size_n,
                                            num_bits=8)
    layer.weight = torch.nn.Parameter(marlin_qweight, requires_grad=False)

    # WEIGHT SCALES
    scales = layer.weight_scale.to(layer.orig_dtype)
    # Permute scales
    if "weight_scale" in dir(layer):
        scales = layer.weight_scale.to(layer.orig_dtype)
    elif "weight_scale_inv" in dir(layer):
        scales = layer.weight_scale_inv.to(layer.orig_dtype)
        del layer.weight_scale_inv

    if layer.weight_block_size is None:
        group_size = -1
    else:
        group_size = layer.weight_block_size[1]

    # marlin kernel only supports channel-wise and group-wise quantization
    # we need to convert the scales
    if layer.weight_block_size is None:
        if scales.nelement() == 1:
            # tensor-wise quantization -> channel-wise quantization
            # (1, 1) =>(repeat)=> (1, size_n)
            scales = scales.view(1, 1).repeat_interleave(part_size_n, 1)
        elif scales.nelement() > 1 and scales.nelement() != part_size_n:
            assert part_size_n % scales.nelement() == 0
            s_size = scales.nelement()
            # tensor-wise quantization (for gate-up proj)
            # -> channel-wise quantization
            # (1, s_size) =>(repeat)=> (1, size_n)
            scales = scales.view(1, s_size)
            scales = scales.repeat_interleave(part_size_n // s_size, 1)
        else:
            # channel-wise quantization
            # (1, size_n)
            scales = scales.view(1, part_size_n)
    else:
        # block-wise quantization -> group-wise quantization
        # (size_k // block_size[1], ceil(size_n / block_size[0]))
        # =>(repeat)=> (size_k // block_size[1], size_n)
        block_n = layer.weight_block_size[0]
        scales = scales.T.repeat_interleave(block_n, 1)
        # size_n may not be divisible by block_size[0]
        scales = scales[:, :part_size_n]

    marlin_scales = marlin_permute_scales(s=scales,
                                          size_k=part_size_k,
                                          size_n=part_size_n,
                                          group_size=-1)
                                          group_size=group_size)
    layer.weight_scale = torch.nn.Parameter(marlin_scales, requires_grad=False)

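To make the scale conversion above concrete, here is a toy, purely illustrative example of the block-wise to group-wise step with made-up sizes: each column of the (k-blocks, n-blocks) scale grid is repeated block_n times, so every output channel ends up with its own scale in each k-group before marlin_permute_scales is applied.

import torch

size_n, size_k = 8, 256
block_n, block_k = 4, 128
# Scales arrive as one value per (k-block, n-block) tile.
block_scales = torch.arange(
    1.0, 1.0 + (size_k // block_k) * (size_n // block_n)).reshape(
        size_k // block_k, size_n // block_n)
# Repeat each n-block column block_n times -> one scale per output channel.
group_scales = block_scales.repeat_interleave(block_n, dim=1)
assert group_scales.shape == (size_k // block_k, size_n)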

def pack_fp8_to_int32(fp8_tensor: torch.Tensor) -> torch.Tensor:
def prepare_moe_fp8_layer_for_marlin(layer: torch.nn.Module,
                                     size_k_first: bool = True) -> None:
    logger.warning_once(
        "Your GPU does not have native support for FP8 computation but "
        "FP8 quantization is being used. Weight-only FP8 compression will "
        "be used leveraging the Marlin kernel. This may degrade "
        "performance for compute-heavy workloads.")

    e = layer.num_experts
    k = layer.hidden_size
    n = layer.intermediate_size_per_partition

    # WORKSPACE
    device = layer.w13_weight.device
    layer.workspace = marlin_make_workspace_new(device, 4)
    perm = torch.empty(0, dtype=torch.int, device=device)

    # WEIGHT
    # Repack weights to marlin format
    for name in ["w13_weight", "w2_weight"]:
        weight = getattr(layer, name)
        tensor_list = []
        if "w13" in name:
            size_n, size_k = n * 2, k
        else:
            size_n, size_k = k, n

        if size_k_first:
            assert weight.shape == (e, size_k, size_n)
        else:
            assert weight.shape == (e, size_n, size_k)

        for i in range(e):
            qweight = pack_fp8_to_int32(weight[i], size_k_first)
            if not size_k_first:
                qweight = qweight.T.contiguous()

            marlin_qweight = ops.gptq_marlin_repack(b_q_weight=qweight,
                                                    perm=perm,
                                                    size_k=size_k,
                                                    size_n=size_n,
                                                    num_bits=8)
            tensor_list.append(marlin_qweight)

        weight = torch.cat([x.unsqueeze(0) for x in tensor_list], 0)
        weight = torch.nn.Parameter(weight, requires_grad=False)

        setattr(layer, name, weight)

    # WEIGHT SCALES
    # Permute scales
    if layer.weight_block_size is None:
        group_size = -1
    else:
        group_size = layer.weight_block_size[1]

    for name in ["w13", "w2"]:
        if name + "_weight_scale" in dir(layer):
            new_name = name + "_weight_scale"
            scales = getattr(layer, new_name).to(layer.orig_dtype)
            delattr(layer, new_name)
        elif name + "_weight_scale_inv" in dir(layer):
            new_name = name + "_weight_scale_inv"
            scales = getattr(layer, new_name).to(layer.orig_dtype)
            delattr(layer, new_name)

        tensor_list = []
        if "w13" in name:
            size_n, size_k = n * 2, k
        else:
            size_n, size_k = k, n

        # marlin kernel only supports channel-wise and group-wise quantization
        # we need to convert the scales
        if layer.weight_block_size is None:
            if scales.nelement() == e:
                # tensor-wise quantization -> channel-wise quantization
                # (e, 1, 1) =>(repeat)=> (e, 1, size_n)
                scales = scales.view(e, 1, 1).repeat_interleave(size_n, 2)
            elif scales.nelement() > e and scales.nelement() != e * size_n:
                assert (e * size_n) % scales.nelement() == 0
                s_size = scales.nelement() // e
                # tensor-wise quantization (for gate-up proj)
                # -> channel-wise quantization
                # (e, 1, s_size) =>(repeat)=> (e, 1, size_n)
                scales = scales.view(e, 1, s_size)
                scales = scales.repeat_interleave(size_n // s_size, 2)
            else:
                # channel-wise quantization
                # (e, 1, size_n)
                scales = scales.view(e, 1, size_n)
        else:
            # block-wise quantization -> group-wise quantization
            # (e, size_k // block_size[1], ceil(size_n / block_size[0]))
            # =>(repeat)=> (e, size_k // block_size[1], size_n)
            block_n = layer.weight_block_size[0]
            scales = scales.permute(0, 2, 1).repeat_interleave(block_n, 2)
            # size_n may not be divisible by block_size[0]
            scales = scales[..., :size_n].contiguous()

        for i in range(e):
            marlin_scales = marlin_permute_scales(s=scales[i],
                                                  size_k=size_k,
                                                  size_n=size_n,
                                                  group_size=group_size)
            tensor_list.append(marlin_scales)

        scales = torch.cat([x.unsqueeze(0) for x in tensor_list], 0)
        scales = torch.nn.Parameter(scales, requires_grad=False)

        setattr(layer, name + "_weight_scale", scales)


def pack_fp8_to_int32(fp8_tensor: torch.Tensor,
                      size_k_first: bool = True) -> torch.Tensor:
    """
    Repack FP8 weights to gptq format (packed int32 elements)
    """
    assert fp8_tensor.dtype == torch.float8_e4m3fn
    assert fp8_tensor.shape[0] % 4 == 0
    assert fp8_tensor.ndim == 2

    # Reshape to prepare for packing
    reshaped = fp8_tensor.reshape(-1, 4, *fp8_tensor.shape[1:])
    fp8_tensor = fp8_tensor.T if size_k_first else fp8_tensor
    fp8_tensor = fp8_tensor.contiguous()
    # fp8_tensor is contiguous and has shape (N, K) now
    # with `.view(torch.int32)`, it becomes (N, K // 4)
    int32_tensor = fp8_tensor.view(torch.int32)
    return int32_tensor.T.contiguous() if size_k_first else int32_tensor

    # Convert fp8 to uint8 (byte) representation
    byte_tensor = reshaped.view(torch.uint8)

    # Pack 4 uint8 values into one int32
    packed = (byte_tensor[:, 0].to(torch.int32) |
              (byte_tensor[:, 1].to(torch.int32) << 8) |
              (byte_tensor[:, 2].to(torch.int32) << 16) |
              (byte_tensor[:, 3].to(torch.int32) << 24))
def marlin_quant_fp8_torch(weight, group_size):
    size_n, size_k = weight.shape
    device = weight.device

    return packed.view(fp8_tensor.shape[0] // 4,
                       *fp8_tensor.shape[1:]).contiguous()
    if group_size != -1:
        scales = weight.view(size_n, -1, group_size).abs().max(-1)[0] / 448
        repeated_scales = scales.repeat_interleave(group_size, 1)
        fp8_weight = (weight / repeated_scales).to(torch.float8_e4m3fn)
        weight_ref = fp8_weight.to(weight.dtype) * repeated_scales
    else:
        scales = weight.view(size_n, 1, group_size).abs().max(-1)[0] / 448
        repeated_scales = scales.repeat_interleave(size_k, 1)
        fp8_weight = (weight / repeated_scales).to(torch.float8_e4m3fn)
        weight_ref = fp8_weight.to(weight.dtype) * repeated_scales

    packed_weight = pack_fp8_to_int32(fp8_weight, False).T.contiguous()
    marlin_qweight = ops.gptq_marlin_repack(
        b_q_weight=packed_weight,
        perm=torch.empty(0, dtype=torch.int, device=device),
        size_k=size_k,
        size_n=size_n,
        num_bits=8,
    )

    marlin_scales = marlin_permute_scales(s=scales.T,
                                          size_k=size_k,
                                          size_n=size_n,
                                          group_size=group_size)

    return weight_ref.T, marlin_qweight, marlin_scales
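The rewritten pack_fp8_to_int32 replaces the manual shift-and-or packing with a plain dtype reinterpretation via `.view(torch.int32)`. A small, illustrative self-check that the two agree, assuming a little-endian host (which is what the reinterpretation relies on); the tensors here are stand-ins, not vLLM data:

import torch

# Stand-in for raw FP8 bytes; keep the high byte < 0x80 so the manual shift
# variant stays within int32 range.
raw = torch.randint(0, 128, (8, 4), dtype=torch.uint8)
viewed = raw.view(torch.int32).squeeze(-1)   # reinterpret 4 bytes as one int32
manual = (raw[:, 0].to(torch.int32)
          | (raw[:, 1].to(torch.int32) << 8)
          | (raw[:, 2].to(torch.int32) << 16)
          | (raw[:, 3].to(torch.int32) << 24))
assert torch.equal(viewed, manual)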
@ -6,6 +6,8 @@ from dataclasses import dataclass
from enum import Enum
from typing import Optional, Union

_SCALAR_TYPES_ID_MAP = {}


# Mirrors enum in `core/scalar_type.hpp`
class NanRepr(Enum):
@ -158,6 +160,8 @@ class ScalarType:
        assert offset <= 64, \
            f"ScalarType fields too big {offset} to fit into an int64"

        _SCALAR_TYPES_ID_MAP[val] = self

        return val

    @property
@ -295,6 +299,13 @@ class ScalarType:
        ret.id  # noqa B018: make sure the id is cached
        return ret

    @classmethod
    def from_id(cls, scalar_type_id: int):
        if scalar_type_id not in _SCALAR_TYPES_ID_MAP:
            raise ValueError(
                f"scalar_type_id {scalar_type_id} doesn't exists.")
        return _SCALAR_TYPES_ID_MAP[scalar_type_id]


# naming generally follows: https://github.com/jax-ml/ml_dtypes
# for floating point types (leading f) the scheme is:
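The new from_id classmethod exists so that a ScalarType can cross a custom-op boundary as a plain integer (see the quant_type_id argument of fused_marlin_moe above) and be recovered on the other side. A stripped-down sketch of that registry pattern, with hypothetical names; the real implementation is the _SCALAR_TYPES_ID_MAP shown in the hunk above:

_REGISTRY: dict[int, "Kind"] = {}


class Kind:
    def __init__(self, name: str, ident: int):
        self.name = name
        self.id = ident
        _REGISTRY[ident] = self          # mirrors _SCALAR_TYPES_ID_MAP[val] = self

    @classmethod
    def from_id(cls, ident: int) -> "Kind":
        if ident not in _REGISTRY:
            raise ValueError(f"id {ident} does not exist")
        return _REGISTRY[ident]


uint4b8 = Kind("uint4b8", 17)            # made-up id for illustration
assert Kind.from_id(uint4b8.id) is uint4b8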