[Kernel] fp4 marlin kernel (#17687)

Signed-off-by: Jinzhen Lin <linjinzhen@hotmail.com>
Author: Jinzhen Lin
Date: 2025-05-11 10:58:49 +08:00
Committed by: GitHub
Parent: ca66a1674c
Commit: d74e5f37bc
21 changed files with 1216 additions and 331 deletions

View File

@ -315,6 +315,8 @@ static inline constexpr auto kS8 = ScalarType::int_(8);
static inline constexpr auto kU8 = ScalarType::uint(8);
static inline constexpr auto kU8B128 = ScalarType::uint(8, 128);
static inline constexpr auto kFE2M1f =
ScalarType::float_(2, 1, true, ScalarType::NAN_NONE);
static inline constexpr auto kFE3M2f =
ScalarType::float_(3, 2, true, ScalarType::NAN_NONE);
static inline constexpr auto kFE4M3fn =
@ -332,6 +334,7 @@ static inline constexpr auto kInt8 = kS8;
static inline constexpr auto kUint8 = kU8;
static inline constexpr auto kUint8b128 = kU8B128;
static inline constexpr auto kFloat4_e2m1f = kFE2M1f;
static inline constexpr auto kFloat6_e3m2f = kFE3M2f;
static inline constexpr auto kFloat8_e4m3fn = kFE4M3fn;
static inline constexpr auto kFloat8_e5m2 = kFE5M2;

View File

@ -31,7 +31,10 @@ TEMPLATE = ("template __global__ void Marlin<"
# int8 with zero point case (vllm::kU8) is also supported,
# we don't add it to reduce wheel size.
SCALAR_TYPES = ["vllm::kU4", "vllm::kU4B8", "vllm::kU8B128", "vllm::kFE4M3fn"]
SCALAR_TYPES = [
"vllm::kU4", "vllm::kU4B8", "vllm::kU8B128", "vllm::kFE4M3fn",
"vllm::kFE2M1f"
]
THREAD_CONFIGS = [(128, 128, 256), (64, 256, 256), (64, 128, 128)]
THREAD_M_BLOCKS = [0.5, 1, 2, 3, 4]
@ -39,7 +42,7 @@ THREAD_M_BLOCKS = [0.5, 1, 2, 3, 4]
# = 0 : act order case
# = -1 : channelwise quantization
# > 0 : group_size=16*group_blocks
GROUP_BLOCKS = [0, -1, 2, 4, 8]
GROUP_BLOCKS = [0, -1, 1, 2, 4, 8]
DTYPES = ["fp16", "bf16"]
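Illustration (not part of the generator script): the GROUP_BLOCKS values map to kernel group sizes exactly as the comment above describes, and the newly added value 1 corresponds to nvfp4's group_size of 16. A minimal host-side sketch of that mapping, with the special values spelled out:

```cpp
// Sketch of the GROUP_BLOCKS encoding used by the kernel generator above.
#include <cstdio>
#include <initializer_list>

static void describe(int group_blocks) {
  if (group_blocks == 0)
    std::printf("group_blocks=%2d -> act-order (g_idx) case\n", group_blocks);
  else if (group_blocks == -1)
    std::printf("group_blocks=%2d -> channelwise quantization\n", group_blocks);
  else
    std::printf("group_blocks=%2d -> group_size=%d\n", group_blocks,
                16 * group_blocks);
}

int main() {
  // The new value 1 (group_size 16) is the nvfp4 case.
  for (int gb : {0, -1, 1, 2, 4, 8}) describe(gb);
  return 0;
}
```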
@ -72,6 +75,12 @@ def generate_new_kernels():
# for fp8
if scalar_type == "vllm::kFE4M3fn" and group_blocks not in [-1, 8]:
continue
# nvfp4 only supports group_size == 16
if scalar_type == "vllm::kFE2M1f" and group_blocks not in [1, 2]:
continue
# other quantization methods don't support group_size = 16
if scalar_type != "vllm::kFE2M1f" and group_blocks == 1:
continue
k_blocks = thread_configs[0] // 16
n_blocks = thread_configs[1] // 16

View File

@ -7,17 +7,18 @@
#include "quantization/gptq_marlin/marlin_dtypes.cuh"
#include "core/scalar_type.hpp"
#define MARLIN_KERNEL_PARAMS \
const int4 *__restrict__ A, const int4 *__restrict__ B, \
int4 *__restrict__ C, int4 *__restrict__ C_tmp, \
const int4 *__restrict__ scales_ptr, const int4 *__restrict__ zp_ptr, \
const int *__restrict__ g_idx, \
const int32_t *__restrict__ sorted_token_ids_ptr, \
const int32_t *__restrict__ expert_ids_ptr, \
const int32_t *__restrict__ num_tokens_past_padded_ptr, \
const float *__restrict__ topk_weights_ptr, int top_k, \
bool mul_topk_weights, bool is_ep, int num_groups, int prob_m, \
int prob_n, int prob_k, int *locks, bool use_atomic_add, \
#define MARLIN_KERNEL_PARAMS \
const int4 *__restrict__ A, const int4 *__restrict__ B, \
int4 *__restrict__ C, int4 *__restrict__ C_tmp, \
const int4 *__restrict__ scales_ptr, \
const uint16_t *__restrict__ scale2_ptr, \
const int4 *__restrict__ zp_ptr, const int *__restrict__ g_idx, \
const int32_t *__restrict__ sorted_token_ids_ptr, \
const int32_t *__restrict__ expert_ids_ptr, \
const int32_t *__restrict__ num_tokens_past_padded_ptr, \
const float *__restrict__ topk_weights_ptr, int top_k, \
bool mul_topk_weights, bool is_ep, int num_groups, int prob_m, \
int prob_n, int prob_k, int *locks, bool use_atomic_add, \
bool use_fp32_reduce, int max_shared_mem
namespace MARLIN_NAMESPACE_NAME {

View File

@ -301,9 +301,11 @@ __global__ void Marlin(
int4* __restrict__ C_tmp, // fp32 tmp output buffer (for reduce)
const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape
// (k/groupsize)xn
const int4* __restrict__ zp_ptr, // 4bit packed zero-points of shape
// (k/groupsize)x(n/pack_factor)
const int* __restrict__ g_idx, // int32 group indices of shape k
const uint16_t* __restrict__ scale2_ptr, // fp16 global scale (for nvfp4
// only)
const int4* __restrict__ zp_ptr, // 4bit packed zero-points of shape
// (k/groupsize)x(n/pack_factor)
const int* __restrict__ g_idx, // int32 group indices of shape k
const int32_t* __restrict__ sorted_token_ids_ptr, // moe sorted_ids
const int32_t* __restrict__ expert_ids_ptr, // moe expert ids
const int32_t* __restrict__ num_tokens_past_padded_ptr, // moe num tokens
@ -341,6 +343,16 @@ __global__ void Marlin(
extern __shared__ int4 sh[];
static constexpr auto w_type = vllm::ScalarType::from_id(w_type_id);
constexpr bool has_zp = w_type == vllm::kU4 || w_type == vllm::kU8;
constexpr bool is_int_type = w_type == vllm::kU4 || w_type == vllm::kU8 ||
w_type == vllm::kU4B8 || w_type == vllm::kU8B128;
// see the comments in dequant.h for more details
constexpr bool dequant_skip_flop =
!is_int_type ||
has_zp && !is_zp_float && !std::is_same<scalar_t, nv_bfloat16>::value ||
has_zp && !is_zp_float && !(w_type == vllm::kU8);
scalar_t2 global_scale;
constexpr bool has_act_order = group_blocks == 0;
constexpr int pack_factor = 32 / w_type.size_bits();
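Illustration (not part of the diff): the `dequant_skip_flop` flag introduced a few lines up decides whether the floating-point half of dequantization (the bias subtract/multiply described in dequant.h) is skipped and folded elsewhere. A minimal host-side sketch of the same predicate, using illustrative stand-ins for the `vllm::ScalarType` constants, makes the dispatch explicit; note that `&&` binds tighter than `||`, exactly as in the kernel expression:

```cpp
// Host-side sketch: evaluate the dequant_skip_flop predicate from the hunk
// above for each supported weight type, for fp16 and bf16 activations.
// Type names are illustrative stand-ins; is_zp_float is false for all of them.
#include <cstdio>

enum class WType { kU4, kU8, kU4B8, kU8B128, kFE4M3fn, kFE2M1f };

static bool dequant_skip_flop(WType w, bool is_zp_float, bool is_bf16) {
  bool has_zp = w == WType::kU4 || w == WType::kU8;
  bool is_int = has_zp || w == WType::kU4B8 || w == WType::kU8B128;
  return !is_int ||
         (has_zp && !is_zp_float && !is_bf16) ||
         (has_zp && !is_zp_float && w != WType::kU8);
}

int main() {
  const char* names[] = {"kU4", "kU8", "kU4B8", "kU8B128", "kFE4M3fn", "kFE2M1f"};
  for (int i = 0; i < 6; ++i)
    for (int bf16 = 0; bf16 < 2; ++bf16)
      std::printf("%-9s %s -> skip_flop=%d\n", names[i], bf16 ? "bf16" : "fp16",
                  dequant_skip_flop(WType(i), /*is_zp_float=*/false, bf16));
  return 0;
}
```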
@ -348,7 +360,8 @@ __global__ void Marlin(
constexpr int moe_block_size = m_block_size_8 ? 8 : (16 * thread_m_blocks);
const int group_size =
(!has_act_order && group_blocks == -1) ? prob_k : prob_k / num_groups;
const int scales_expert_stride = prob_n * prob_k / group_size / 8;
const int scales_expert_stride =
prob_n * prob_k / group_size / (w_type == vllm::kFE2M1f ? 16 : 8);
const int zp_expert_stride =
is_zp_float ? prob_n * prob_k / group_size / 8
: prob_n * prob_k / group_size / (pack_factor * 4);
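Illustration (not part of the diff): nvfp4 block scales are stored as 1-byte fp8 values (see the `at::Float8_e4m3fn` scale pointers later in this commit) rather than 2-byte fp16, so twice as many of them fit into one 16-byte `int4` vector, which is why the expert stride above divides by 16 instead of 8. A small host-side sketch with made-up sizes:

```cpp
// Sketch only: the int4-element scale stride computed above, for illustrative
// problem sizes. The kernel receives prob_n / prob_k / group_size at runtime.
#include <cstdio>

int main() {
  int prob_n = 4096, prob_k = 4096, group_size = 16;
  int num_scales = prob_n * prob_k / group_size;  // one block scale per group
  int stride_fp16 = num_scales / 8;   // fp16 scales: 8 per 16-byte int4
  int stride_fp8 = num_scales / 16;   // nvfp4 fp8 scales: 16 per 16-byte int4
  std::printf("scales per expert: %d, int4 stride (fp16): %d, (fp8): %d\n",
              num_scales, stride_fp16, stride_fp8);
  return 0;
}
```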
@ -460,9 +473,16 @@ __global__ void Marlin(
if (mul_topk_weights) {
#pragma unroll
for (int i = 0; i < 4; i++) {
sh_block_topk_weights[tid4 * 4 + i] =
Dtype::num2num2(Dtype::float2num(
topk_weights_ptr[sh_block_sorted_ids[tid4 * 4 + i]]));
if constexpr (w_type == vllm::kFE2M1f) {
sh_block_topk_weights[tid4 * 4 + i] = __hmul2(
global_scale,
Dtype::num2num2(Dtype::float2num(
topk_weights_ptr[sh_block_sorted_ids[tid4 * 4 + i]])));
} else {
sh_block_topk_weights[tid4 * 4 + i] =
Dtype::num2num2(Dtype::float2num(
topk_weights_ptr[sh_block_sorted_ids[tid4 * 4 + i]]));
}
}
}
}
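Illustration (not part of the diff): for nvfp4 the per-tensor `global_scale` is folded into the shared top-k weights here, so the epilogue only multiplies by `global_scale` when `mul_topk_weights` is false (see the write-out hunk further down in this file). The two orderings are algebraically the same; a trivial sketch with made-up values:

```cpp
// Illustration only: made-up scalar values; in the kernel these are half2 lanes.
#include <cstdio>

int main() {
  float acc = 3.25f;          // accumulator after the MMA loop
  float global_scale = 0.5f;  // per-tensor nvfp4 scale
  float topk_weight = 0.75f;  // MoE routing weight

  float fused = acc * (global_scale * topk_weight);  // scale folded into the weight
  float late = (acc * topk_weight) * global_scale;   // scale applied at write-out
  std::printf("fused=%f late=%f\n", fused, late);
  return fused == late ? 0 : 1;  // identical here (exact binary fractions)
}
```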
@ -493,6 +513,11 @@ __global__ void Marlin(
expert_id = expert_ids_ptr[block_id];
}
if constexpr (w_type == vllm::kFE2M1f) {
uint16_t val = scale2_ptr[expert_id];
global_scale = Dtype::num2num2(*reinterpret_cast<scalar_t*>(&val));
}
B_expert_off = expert_id * prob_n * prob_k / (pack_factor * 4);
scales_ptr += (expert_id - old_expert_id) * scales_expert_stride;
if constexpr (has_zp) {
@ -606,7 +631,7 @@ __global__ void Marlin(
constexpr int s_sh_stride = 16 * thread_n_blocks / 8;
constexpr int s_tb_groups =
!has_act_order && group_blocks != -1 && group_blocks < thread_k_blocks
? thread_k_blocks / group_blocks
? thread_k_blocks / group_blocks / (w_type == vllm::kFE2M1f ? 2 : 1)
: 1;
constexpr int s_sh_stage = s_tb_groups * s_sh_stride;
int s_gl_rd_delta = s_gl_stride;
@ -664,7 +689,8 @@ __global__ void Marlin(
if constexpr (group_blocks == -1) {
s_gl_rd = s_sh_stride * slice_col + threadIdx.x;
} else {
s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) +
s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) /
(w_type == vllm::kFE2M1f ? 2 : 1) +
s_sh_stride * slice_col + threadIdx.x;
}
}
@ -688,10 +714,20 @@ __global__ void Marlin(
// we scale a `half2` tile in column-major layout in the former and in
// row-major in the latter case.
int s_sh_rd;
if constexpr (group_blocks != -1)
if constexpr (group_blocks != -1 && w_type == vllm::kFE2M1f) {
auto warp_id = threadIdx.x / 32;
int n_warps = thread_n_blocks / 4;
int warp_row = warp_id / n_warps;
s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
(threadIdx.x % 32) / 4;
else if constexpr (group_blocks == -1 && (m_block_size_8 || has_zp))
s_sh_rd = s_sh_rd * 2 + warp_row % 2;
} else if constexpr (group_blocks != -1)
s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
(threadIdx.x % 32) / 4;
else if constexpr (group_blocks == -1 &&
(m_block_size_8 || (has_zp && !dequant_skip_flop)))
s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
(threadIdx.x % 32) / 8;
else
@ -801,7 +837,7 @@ __global__ void Marlin(
sh_first_group_id = first_group_id;
sh_num_groups = last_group_id - first_group_id + 1;
if (sh_num_groups < act_s_max_num_groups) {
if (sh_num_groups > act_s_max_num_groups) {
sh_num_groups = act_s_max_num_groups;
}
@ -1021,12 +1057,19 @@ __global__ void Marlin(
cur_k += k_iter_size * (k % b_sh_wr_iters);
int k_blocks = cur_k / 16;
int cur_group_id = k_blocks / group_blocks;
int cur_group_id =
k_blocks / (group_blocks * (w_type == vllm::kFE2M1f ? 2 : 1));
int4* sh_s_stage = sh_s + s_sh_stage * pipe;
reinterpret_cast<int4*>(&frag_s[k % 2])[0] =
sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride];
if constexpr (w_type_id != vllm::kFE2M1f.id()) {
reinterpret_cast<int4*>(&frag_s[k % 2])[0] =
sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride];
} else {
reinterpret_cast<int2*>(&frag_s[k % 2])[0] =
reinterpret_cast<int2*>(
sh_s_stage)[s_sh_rd + cur_group_id * (2 * s_sh_stride)];
}
}
}
@ -1199,22 +1242,7 @@ __global__ void Marlin(
};
auto dequant_data = [&](int q, scalar_t2* frag_b_ptr) {
if constexpr (has_zp && is_zp_float || !has_zp) {
dequant<scalar_t2, w_type_id>(q, frag_b_ptr);
} else {
static_assert(has_zp && !is_zp_float);
static_assert(w_type_id == vllm::kU4.id() || w_type_id == vllm::kU8.id());
// If (has_zp && !is_zp_float),
// we use not-zp version `dequant` function
// to improve numerical accuracy.
// Since both weight and zero point are dequanted using this logic,
// the final dequanted weight would be correct.
if constexpr (w_type_id == vllm::kU4.id()) {
dequant<scalar_t2, vllm::kU4B8.id()>(q, frag_b_ptr);
} else if constexpr (w_type_id == vllm::kU8.id()) {
dequant<scalar_t2, vllm::kU8B128.id()>(q, frag_b_ptr);
}
}
dequant<scalar_t2, w_type_id, dequant_skip_flop>(q, frag_b_ptr);
};
// Execute the actual tensor core matmul of a sub-tile.
@ -1244,13 +1272,23 @@ __global__ void Marlin(
dequant_data(zp_quant_1, reinterpret_cast<scalar_t2*>(&frag_zp) + 2);
}
}
if constexpr (has_zp && is_zp_float) {
if constexpr (!dequant_skip_flop && has_zp && is_zp_float) {
if (is_new_zp) {
reinterpret_cast<int4*>(&frag_zp)[0] =
reinterpret_cast<int4*>(&frag_zpf[k2])[0];
}
}
if constexpr (w_type == vllm::kFE2M1f) {
int s_quant_0 = reinterpret_cast<int*>(frag_s[k2])[0];
int s_quant_1 = reinterpret_cast<int*>(frag_s[k2])[1];
dequant_fp8_scales<scalar_t2>(s_quant_0,
reinterpret_cast<scalar_t2*>(&frag_s[k2]));
dequant_fp8_scales<scalar_t2>(
s_quant_1, reinterpret_cast<scalar_t2*>(&frag_s[k2]) + 2);
}
// We have the m dimension as the inner loop in order to encourage overlapping
// dequantization and matmul operations.
#pragma unroll
@ -1259,7 +1297,10 @@ __global__ void Marlin(
FragB frag_b1;
int b_quant_0, b_quant_1;
if constexpr (w_type.size_bits() == 4) {
if constexpr (w_type_id == vllm::kFE2M1f.id()) {
b_quant_1 = frag_b_quant[k2][0][j];
b_quant_0 = b_quant_1 << 8;
} else if constexpr (w_type.size_bits() == 4) {
b_quant_0 = frag_b_quant[k2][0][j];
b_quant_1 = b_quant_0 >> 8;
} else {
@ -1272,6 +1313,11 @@ __global__ void Marlin(
dequant_data(b_quant_0, reinterpret_cast<scalar_t2*>(&frag_b0));
dequant_data(b_quant_1, reinterpret_cast<scalar_t2*>(&frag_b1));
if constexpr (dequant_skip_flop && has_zp && !is_zp_float) {
sub_zp<scalar_t>(frag_b0, frag_zp[j], 0);
sub_zp<scalar_t>(frag_b1, frag_zp[j], 1);
}
// Apply scale to frag_b0
if constexpr (has_act_order) {
static_assert(group_blocks != -1);
@ -1279,7 +1325,8 @@ __global__ void Marlin(
act_frag_s[k2][2][j], act_frag_s[k2][3][j], 0);
scale4<scalar_t>(frag_b1, act_frag_s[k2][0][j], act_frag_s[k2][1][j],
act_frag_s[k2][2][j], act_frag_s[k2][3][j], 1);
} else if constexpr (has_zp && !is_zp_float && group_blocks == -1) {
} else if constexpr (!dequant_skip_flop && has_zp && !is_zp_float &&
group_blocks == -1) {
int idx = (threadIdx.x / 4) % 2;
scalar_t2 s2 = Dtype::nums2num2(
reinterpret_cast<scalar_t*>(&frag_s[j / 2][j % 2 * 2 + 0])[idx],
@ -1287,7 +1334,7 @@ __global__ void Marlin(
if (is_new_zp) frag_zp[j] = __hmul2(frag_zp[j], s2);
scale_and_sub<scalar_t>(frag_b0, s2.x, frag_zp[j].x);
scale_and_sub<scalar_t>(frag_b1, s2.y, frag_zp[j].y);
} else if constexpr (has_zp && group_blocks != -1) {
} else if constexpr (!dequant_skip_flop && has_zp && group_blocks != -1) {
if (is_new_zp)
frag_zp[j] = __hmul2(frag_zp[j],
*reinterpret_cast<scalar_t2*>(&frag_s[k2][j]));
@ -1554,10 +1601,17 @@ __global__ void Marlin(
// For per-column quantization we finally apply the scale here (only for
// 4-bit)
if constexpr (!has_act_order && group_blocks == -1 &&
w_type.size_bits() == 4 && !has_zp) {
w_type.size_bits() == 4 &&
(has_zp && dequant_skip_flop || !has_zp)) {
res = __hmul2(res, s[0]);
}
if constexpr (w_type == vllm::kFE2M1f) {
if (!mul_topk_weights) {
res = __hmul2(res, global_scale);
}
}
if constexpr (m_block_size_8) {
((scalar_t*)sh_red)[idx] = res.x;
((scalar_t*)sh_red)[idx + 8 * c_sh_stride] = res.y;
@ -1648,7 +1702,9 @@ __global__ void Marlin(
if constexpr (has_zp && !is_zp_float && group_blocks == -1) {
if (i == 0) {
fetch_col_zp_to_shared();
fetch_col_scale_to_shared();
if constexpr (!dequant_skip_flop) {
fetch_col_scale_to_shared();
}
}
}
fetch_to_shared(i, i, i < slice_iters, i);
@ -1737,7 +1793,8 @@ __global__ void Marlin(
bool last = slice_idx == slice_count - 1;
// For per-column scales, we only fetch them here in the final step before
// write-out
if constexpr (!has_act_order && group_blocks == -1 && !has_zp) {
if constexpr (!has_act_order && group_blocks == -1 &&
(has_zp && dequant_skip_flop || !has_zp)) {
if (w_type.size_bits() == 8 || (last || use_atomic_add)) {
if (s_sh_wr_pred) {
cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]);
@ -1747,7 +1804,8 @@ __global__ void Marlin(
}
thread_block_reduce();
if constexpr (!has_act_order && group_blocks == -1 && !has_zp) {
if constexpr (!has_act_order && group_blocks == -1 &&
(has_zp && dequant_skip_flop || !has_zp)) {
if (w_type.size_bits() == 8 || (last || use_atomic_add)) {
cp_async_wait<0>();
__syncthreads();
@ -1771,7 +1829,8 @@ __global__ void Marlin(
// that converts the fp32 results to fp16 (so that we avoid possible
// overflow in fp16)
if constexpr (!has_act_order && group_blocks == -1 &&
w_type.size_bits() == 8 && !has_zp) {
w_type.size_bits() == 8 &&
(has_zp && dequant_skip_flop || !has_zp)) {
if (threadIdx.x / 32 < thread_n_blocks / 4) {
#pragma unroll
for (int i = 0; i < thread_m_blocks; i++) {

View File

@ -291,6 +291,7 @@ bool is_valid_config(thread_config_t const& th_config, bool m_block_size_8,
// BIGGROUP: cases for big group size (group_blocks in [-1, 8])
// FZP: cases for float-zero-point (is_zp_float = true)
// ACT: cases for act order case (group_blocks == 0)
// FP4: cases for nvfp4(e2m1) (group_blocks == 1)
#define COMMON_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, -1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 2, NUM_THREADS, false) \
@ -338,6 +339,21 @@ bool is_valid_config(thread_config_t const& th_config, bool m_block_size_8,
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false)
#define FP4_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false)
#define FP4_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false)
#define FP4_GET_IF(W_TYPE) \
FP4_GET_IF_M1(W_TYPE, 8, 8, 256) \
FP4_GET_IF_M1(W_TYPE, 8, 4, 128) \
FP4_GET_IF_M234(W_TYPE, 16, 4, 256) \
FP4_GET_IF_M234(W_TYPE, 8, 4, 128)
#define BIGGROUP_GET_IF(W_TYPE) \
BIGGROUP_GET_IF_M1(W_TYPE, 8, 8, 256) \
BIGGROUP_GET_IF_M1(W_TYPE, 8, 4, 128) \
@ -394,6 +410,8 @@ MarlinFuncPtr get_marlin_kernel(const vllm::ScalarType q_type,
BIGGROUP_GET_IF(vllm::kFE4M3fn)
FP4_GET_IF(vllm::kFE2M1f)
ACT_GET_IF(vllm::kU4B8)
ACT_GET_IF(vllm::kU8B128)
@ -465,7 +483,7 @@ exec_config_t determine_exec_config(const vllm::ScalarType& q_type, int prob_m,
template <typename scalar_t>
void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
void* zp, void* g_idx, void* perm, void* a_tmp,
void* s2, void* zp, void* g_idx, void* perm, void* a_tmp,
void* sorted_token_ids, void* expert_ids,
void* num_tokens_past_padded, void* topk_weights,
int moe_block_size, int top_k, bool mul_topk_weights, bool is_ep,
@ -479,14 +497,16 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
bool m_block_size_8 = moe_block_size == 8;
if (has_zp) {
TORCH_CHECK(q_type == vllm::kU4,
"q_type must be u4 when has_zp = True. Got = ", q_type.str());
TORCH_CHECK(
q_type == vllm::kU4 || q_type == vllm::kU8,
"q_type must be u4 or u8 when has_zp = True. Got = ", q_type.str());
} else {
TORCH_CHECK(q_type == vllm::kU4B8 || q_type == vllm::kU8B128 ||
q_type == vllm::kFE4M3fn,
"q_type must be uint4b8, uint8b128 or fp8e4m3 when has_zp = "
"False. Got = ",
q_type.str());
TORCH_CHECK(
q_type == vllm::kU4B8 || q_type == vllm::kU8B128 ||
q_type == vllm::kFE4M3fn || q_type == vllm::kFE2M1f,
"q_type must be uint4b8, uint8b128, float8_e4m3fn or float4_e2m1f when "
"has_zp = False. Got = ",
q_type.str());
}
TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m,
@ -519,6 +539,7 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
int4* C_ptr = (int4*)C;
int4* C_tmp_ptr = (int4*)C_tmp;
const int4* s_ptr = (const int4*)s;
const uint16_t* s2_ptr = (const uint16_t*)s2;
const int4* zp_ptr = (const int4*)zp;
const int* g_idx_ptr = (const int*)g_idx;
const int* perm_ptr = (const int*)perm;
@ -627,7 +648,7 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
// avoid ">>>" being formatted to "> > >"
// clang-format off
kernel<<<blocks, num_threads, max_shared_mem, stream>>>(
A_ptr, B_ptr, C_ptr, C_tmp_ptr, s_ptr, zp_ptr, g_idx_ptr,
A_ptr, B_ptr, C_ptr, C_tmp_ptr, s_ptr, s2_ptr, zp_ptr, g_idx_ptr,
sorted_token_ids_ptr, expert_ids_ptr, num_tokens_past_padded_ptr,
topk_weights_ptr, top_k, mul_topk_weights, is_ep, num_groups, prob_m,
prob_n, prob_k, locks, use_atomic_add, use_fp32_reduce, max_shared_mem);
@ -639,6 +660,7 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
torch::Tensor moe_wna16_marlin_gemm(
torch::Tensor& a, std::optional<torch::Tensor> const& c_or_none,
torch::Tensor& b_q_weight, torch::Tensor& b_scales,
std::optional<torch::Tensor> const& global_scale_or_none,
std::optional<torch::Tensor> const& b_zeros_or_none,
std::optional<torch::Tensor> const& g_idx_or_none,
std::optional<torch::Tensor> const& perm_or_none, torch::Tensor& workspace,
@ -790,6 +812,17 @@ torch::Tensor moe_wna16_marlin_gemm(
}
}
torch::Tensor global_scale;
if (global_scale_or_none.has_value()) {
global_scale = global_scale_or_none.value();
TORCH_CHECK(b_q_type == vllm::kFE2M1f,
"global_scale can only be used for float4_e2m1f.");
} else {
global_scale = torch::empty({0}, options);
TORCH_CHECK(!(b_q_type == vllm::kFE2M1f),
"the global_scale parameter must be passed for float4_e2m1f.");
}
torch::Tensor b_zeros;
if (b_zeros_or_none.has_value()) {
b_zeros = b_zeros_or_none.value();
@ -802,13 +835,14 @@ torch::Tensor moe_wna16_marlin_gemm(
if (has_zp) {
TORCH_CHECK(
b_q_type == vllm::kU4,
"b_q_type must be u4 when has_zp = True. Got = ", b_q_type.str());
b_q_type == vllm::kU4 || b_q_type == vllm::kU8,
"b_q_type must be u4 or u8 when has_zp = True. Got = ", b_q_type.str());
} else {
TORCH_CHECK(b_q_type == vllm::kU4B8 || b_q_type == vllm::kU8B128 ||
b_q_type == vllm::kFE4M3fn,
"b_q_type must be uint4b8, uint8b128 or fp8e4m3 when has_zp = "
"False. Got = ",
b_q_type == vllm::kFE4M3fn || b_q_type == vllm::kFE2M1f,
"b_q_type must be uint4b8, uint8b128, float8_e4m3fn or "
"float4_e2m1f when "
"has_zp = False. Got = ",
b_q_type.str());
}
@ -854,9 +888,16 @@ torch::Tensor moe_wna16_marlin_gemm(
int dev = a.get_device();
if (a.scalar_type() == at::ScalarType::Half) {
void* scales_ptr;
if (b_q_type == vllm::kFE2M1f) {
scales_ptr = b_scales.data_ptr<at::Float8_e4m3fn>();
} else {
scales_ptr = b_scales.data_ptr<at::Half>();
}
MARLIN_NAMESPACE_NAME::marlin_mm<half>(
a.data_ptr<at::Half>(), b_q_weight.data_ptr(), c.data_ptr<at::Half>(),
c_tmp.data_ptr<float>(), b_scales.data_ptr<at::Half>(),
c_tmp.data_ptr<float>(), scales_ptr, global_scale.data_ptr<at::Half>(),
b_zeros.data_ptr(), g_idx.data_ptr(), perm.data_ptr(),
a_tmp.data_ptr<at::Half>(), sorted_token_ids.data_ptr(),
expert_ids.data_ptr(), num_tokens_past_padded.data_ptr(),
@ -866,11 +907,18 @@ torch::Tensor moe_wna16_marlin_gemm(
at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms,
use_atomic_add, use_fp32_reduce, is_zp_float);
} else if (a.scalar_type() == at::ScalarType::BFloat16) {
void* scales_ptr;
if (b_q_type == vllm::kFE2M1f) {
scales_ptr = b_scales.data_ptr<at::Float8_e4m3fn>();
} else {
scales_ptr = b_scales.data_ptr<at::BFloat16>();
}
MARLIN_NAMESPACE_NAME::marlin_mm<nv_bfloat16>(
a.data_ptr<at::BFloat16>(), b_q_weight.data_ptr(),
c.data_ptr<at::BFloat16>(), c_tmp.data_ptr<float>(),
b_scales.data_ptr<at::BFloat16>(), b_zeros.data_ptr(), g_idx.data_ptr(),
perm.data_ptr(), a_tmp.data_ptr<at::BFloat16>(),
c.data_ptr<at::BFloat16>(), c_tmp.data_ptr<float>(), scales_ptr,
global_scale.data_ptr<at::BFloat16>(), b_zeros.data_ptr(),
g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr<at::BFloat16>(),
sorted_token_ids.data_ptr(), expert_ids.data_ptr(),
num_tokens_past_padded.data_ptr(), topk_weights.data_ptr(),
moe_block_size, top_k, mul_topk_weights, is_ep, size_m, size_n, size_k,

View File

@ -44,7 +44,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
m.def(
"moe_wna16_marlin_gemm(Tensor! a, Tensor? c_or_none,"
"Tensor! b_q_weight, Tensor! b_scales, Tensor? b_zeros_or_none,"
"Tensor! b_q_weight, Tensor! b_scales, Tensor? global_scale, Tensor? "
"b_zeros_or_none,"
"Tensor? g_idx_or_none, Tensor? perm_or_none, Tensor! workspace,"
"Tensor sorted_token_ids,"
"Tensor! expert_ids, Tensor! num_tokens_past_padded,"

View File

@ -1,3 +1,67 @@
/*
Fast Dequantization (Converting INT4/INT8/FP4/FP8 to FP16/BF16)

The process of fast dequantization can be summarized as a combination of
bitwise operations and floating-point computations:

weight =>(bit_op / bitwise operations)=>
f16_value =>(flop / floating-point computation)=>
dequantized_weight

Since the dequantized weights typically require subtracting the zero point and
applying a scale factor, the floating-point computation step can be fused with
the zero-point subtraction and the scaling.

The following describes how the fused zero-point subtraction and scaling must
be adapted for each format.

## INT4 => FP16/BF16 or INT8 => FP16

The floating-point computation is `__hsub2`.

If there are integer zero points:
flop(bit_op(weight)) - flop(bit_op(zp))
  = sub(bit_op(weight), bias) - sub(bit_op(zp), bias)
  = bit_op(weight) - bit_op(zp)
so no additional modification is needed.

If there are float zero points:
flop(bit_op(weight)) - fzp
  = sub(bit_op(weight), bias) - fzp
  = bit_op(weight) - (fzp + bias)
where `fzp + bias` can be precomputed at weight loading. However, this may
cause accuracy issues, so it should not be used in most cases.

If there are no zero points:
scale(flop(bit_op(weight)))
  = scale(sub(bit_op(weight), bias))
  = scale(bit_op(weight)) - scale(bias)
  = fma(bit_op(weight), scale_factor, -scale(bias))
where `-scale(bias)` can be precomputed at weight loading. However, this may
cause accuracy issues, so it should not be used in most cases.

## INT8 => BF16

INT8 => BF16 is a special case: it uses byte_perm instead of a floating-point
computation, and byte_perm cannot be fused with scaling.

## FP4/FP8 => FP16/BF16

scale(flop(bit_op(weight)))
  = scale(mul(bit_op(weight), multiplier))
  = mul(bit_op(weight), scale_factor * multiplier)
where `scale_factor * multiplier` can be precomputed at weight loading.
*/
#include "marlin_dtypes.cuh"
@ -27,7 +91,8 @@ __device__ inline uint32_t prmt(uint32_t a) {
return res;
}
template <typename scalar_t2, vllm::ScalarTypeId w_type_id>
template <typename scalar_t2, vllm::ScalarTypeId w_type_id,
bool skip_flop = false>
__device__ inline void dequant(int q, scalar_t2* frag_b);
//
@ -40,7 +105,22 @@ __device__ inline void dequant(int q, scalar_t2* frag_b);
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L327-L385
//
template <>
__device__ inline void dequant<half2, vllm::kU4B8.id()>(int q, half2* frag_b) {
__device__ inline void dequant<half2, vllm::kU4B8.id(), true>(int q,
half2* frag_b) {
const int MASK = 0x000f000f;
const int EX = 0x64006400;
// Guarantee that the `(a & b) | c` operations are LOP3s.
int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
q >>= 4;
int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
frag_b[0] = *reinterpret_cast<half2*>(&lo);
frag_b[1] = *reinterpret_cast<half2*>(&hi);
}
template <>
__device__ inline void dequant<half2, vllm::kU4B8.id(), false>(int q,
half2* frag_b) {
const int LO = 0x000f000f;
const int HI = 0x00f000f0;
const int EX = 0x64006400;
@ -62,7 +142,14 @@ __device__ inline void dequant<half2, vllm::kU4B8.id()>(int q, half2* frag_b) {
}
template <>
__device__ inline void dequant<half2, vllm::kU4.id()>(int q, half2* frag_b) {
__device__ inline void dequant<half2, vllm::kU4.id(), true>(int q,
half2* frag_b) {
dequant<half2, vllm::kU4B8.id(), true>(q, frag_b);
}
template <>
__device__ inline void dequant<half2, vllm::kU4.id(), false>(int q,
half2* frag_b) {
const int LO = 0x000f000f;
const int HI = 0x00f000f0;
const int EX = 0x64006400;
@ -84,7 +171,7 @@ __device__ inline void dequant<half2, vllm::kU4.id()>(int q, half2* frag_b) {
}
template <>
__device__ inline void dequant<nv_bfloat162, vllm::kU4B8.id()>(
__device__ inline void dequant<nv_bfloat162, vllm::kU4B8.id(), true>(
int q, nv_bfloat162* frag_b) {
static constexpr uint32_t MASK = 0x000f000f;
static constexpr uint32_t EX = 0x43004300;
@ -96,39 +183,36 @@ __device__ inline void dequant<nv_bfloat162, vllm::kU4B8.id()>(
int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
// clang-format on
static constexpr uint32_t MUL = 0x3F803F80;
static constexpr uint32_t ADD = 0xC308C308;
frag_b[0] = __hfma2(*reinterpret_cast<nv_bfloat162*>(&lo),
*reinterpret_cast<const nv_bfloat162*>(&MUL),
*reinterpret_cast<const nv_bfloat162*>(&ADD));
frag_b[1] = __hfma2(*reinterpret_cast<nv_bfloat162*>(&hi),
*reinterpret_cast<const nv_bfloat162*>(&MUL),
*reinterpret_cast<const nv_bfloat162*>(&ADD));
frag_b[0] = *reinterpret_cast<nv_bfloat162*>(&lo);
frag_b[1] = *reinterpret_cast<nv_bfloat162*>(&hi);
}
template <>
__device__ inline void dequant<nv_bfloat162, vllm::kU4.id()>(
__device__ inline void dequant<nv_bfloat162, vllm::kU4B8.id(), false>(
int q, nv_bfloat162* frag_b) {
static constexpr uint32_t MASK = 0x000f000f;
static constexpr uint32_t EX = 0x43004300;
dequant<nv_bfloat162, vllm::kU4B8.id(), true>(q, frag_b);
// Guarantee that the `(a & b) | c` operations are LOP3s.
// clang-format off
int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
q >>= 4;
int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
// clang-format on
static constexpr uint32_t SUB = 0x43084308;
static constexpr uint32_t MUL = 0x3F803F80;
static constexpr uint32_t ADD = 0xC300C300;
frag_b[0] = __hsub2(frag_b[0], *reinterpret_cast<const nv_bfloat162*>(&SUB));
frag_b[1] = __hsub2(frag_b[1], *reinterpret_cast<const nv_bfloat162*>(&SUB));
}
frag_b[0] = __hfma2(*reinterpret_cast<nv_bfloat162*>(&lo),
*reinterpret_cast<const nv_bfloat162*>(&MUL),
*reinterpret_cast<const nv_bfloat162*>(&ADD));
frag_b[1] = __hfma2(*reinterpret_cast<nv_bfloat162*>(&hi),
*reinterpret_cast<const nv_bfloat162*>(&MUL),
*reinterpret_cast<const nv_bfloat162*>(&ADD));
template <>
__device__ inline void dequant<nv_bfloat162, vllm::kU4.id(), true>(
int q, nv_bfloat162* frag_b) {
dequant<nv_bfloat162, vllm::kU4B8.id(), true>(q, frag_b);
}
template <>
__device__ inline void dequant<nv_bfloat162, vllm::kU4.id(), false>(
int q, nv_bfloat162* frag_b) {
dequant<nv_bfloat162, vllm::kU4.id(), true>(q, frag_b);
static constexpr uint32_t SUB = 0x43004300;
frag_b[0] = __hsub2(frag_b[0], *reinterpret_cast<const nv_bfloat162*>(&SUB));
frag_b[1] = __hsub2(frag_b[1], *reinterpret_cast<const nv_bfloat162*>(&SUB));
}
//
@ -140,8 +224,8 @@ __device__ inline void dequant<nv_bfloat162, vllm::kU4.id()>(
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L125-L175
//
template <>
__device__ inline void dequant<half2, vllm::kU8B128.id()>(int q,
half2* frag_b) {
__device__ inline void dequant<half2, vllm::kU8B128.id(), true>(int q,
half2* frag_b) {
static constexpr uint32_t mask_for_elt_01 = 0x5250;
static constexpr uint32_t mask_for_elt_23 = 0x5351;
static constexpr uint32_t start_byte_for_fp16 = 0x64646464;
@ -149,33 +233,42 @@ __device__ inline void dequant<half2, vllm::kU8B128.id()>(int q,
uint32_t lo = prmt<start_byte_for_fp16, mask_for_elt_01>(q);
uint32_t hi = prmt<start_byte_for_fp16, mask_for_elt_23>(q);
frag_b[0] = *reinterpret_cast<half2*>(&lo);
frag_b[1] = *reinterpret_cast<half2*>(&hi);
}
template <>
__device__ inline void dequant<half2, vllm::kU8B128.id(), false>(
int q, half2* frag_b) {
dequant<half2, vllm::kU8B128.id(), true>(q, frag_b);
static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480;
frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo),
frag_b[0] = __hsub2(frag_b[0],
*reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
frag_b[1] = __hsub2(*reinterpret_cast<half2*>(&hi),
frag_b[1] = __hsub2(frag_b[1],
*reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
}
template <>
__device__ inline void dequant<half2, vllm::kU8.id()>(int q, half2* frag_b) {
static constexpr uint32_t mask_for_elt_01 = 0x5250;
static constexpr uint32_t mask_for_elt_23 = 0x5351;
static constexpr uint32_t start_byte_for_fp16 = 0x64646464;
__device__ inline void dequant<half2, vllm::kU8.id(), true>(int q,
half2* frag_b) {
dequant<half2, vllm::kU8B128.id(), true>(q, frag_b);
}
uint32_t lo = prmt<start_byte_for_fp16, mask_for_elt_01>(q);
uint32_t hi = prmt<start_byte_for_fp16, mask_for_elt_23>(q);
template <>
__device__ inline void dequant<half2, vllm::kU8.id(), false>(int q,
half2* frag_b) {
dequant<half2, vllm::kU8.id(), true>(q, frag_b);
static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64006400;
frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo),
frag_b[0] = __hsub2(frag_b[0],
*reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
frag_b[1] = __hsub2(*reinterpret_cast<half2*>(&hi),
frag_b[1] = __hsub2(frag_b[1],
*reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
}
template <>
__device__ inline void dequant<nv_bfloat162, vllm::kU8B128.id()>(
__device__ inline void dequant<nv_bfloat162, vllm::kU8B128.id(), false>(
int q, nv_bfloat162* frag_b) {
float fp32_intermediates[4];
uint32_t* fp32_intermediates_casted =
@ -200,7 +293,7 @@ __device__ inline void dequant<nv_bfloat162, vllm::kU8B128.id()>(
}
template <>
__device__ inline void dequant<nv_bfloat162, vllm::kU8.id()>(
__device__ inline void dequant<nv_bfloat162, vllm::kU8.id(), false>(
int q, nv_bfloat162* frag_b) {
float fp32_intermediates[4];
uint32_t* fp32_intermediates_casted =
@ -225,22 +318,30 @@ __device__ inline void dequant<nv_bfloat162, vllm::kU8.id()>(
}
template <>
__device__ inline void dequant<half2, vllm::kFE4M3fn.id()>(int q,
half2* frag_b) {
__device__ inline void dequant<half2, vllm::kFE4M3fn.id(), true>(
int q, half2* frag_b) {
// Constants for FP8 (E4M3) and FP16 formats
constexpr int FP8_EXPONENT = 4, FP8_MANTISSA = 3, FP16_EXPONENT = 5;
constexpr int FP8_EXPONENT = 4, FP16_EXPONENT = 5;
constexpr int RIGHT_SHIFT = FP16_EXPONENT - FP8_EXPONENT;
// Calculate MASK for extracting mantissa and exponent
constexpr int MASK1 = 0x80000000;
constexpr int MASK2 = MASK1 >> (FP8_EXPONENT + FP8_MANTISSA);
constexpr int MASK3 = MASK2 & 0x7fffffff;
constexpr int MASK = MASK3 | (MASK3 >> 16);
// Final MASK value: 0x7F007F00
constexpr int MASK = 0x7F007F00;
// Extract and shift FP8 values to FP16 format
int Out1 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT);
int Out2 = ((q << 8) & 0x80008000) | (((q << 8) & MASK) >> RIGHT_SHIFT);
q <<= 8;
int Out2 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT);
// Note: reverse indexing is intentional because weights are permuted
frag_b[1] = *reinterpret_cast<const half2*>(&Out1);
frag_b[0] = *reinterpret_cast<const half2*>(&Out2);
}
template <>
__device__ inline void dequant<half2, vllm::kFE4M3fn.id(), false>(
int q, half2* frag_b) {
dequant<half2, vllm::kFE4M3fn.id(), true>(q, frag_b);
// Constants for FP8 (E4M3) and FP16 formats
constexpr int FP8_EXPONENT = 4, FP16_EXPONENT = 5;
// Construct and apply exponent bias
constexpr int BIAS_OFFSET =
@ -248,28 +349,36 @@ __device__ inline void dequant<half2, vllm::kFE4M3fn.id()>(int q,
const half2 bias_reg = __float2half2_rn(float(1 << BIAS_OFFSET));
// Convert to half2 and apply bias
// Note: reverse indexing is intentional because weights are permuted
frag_b[1] = __hmul2(*reinterpret_cast<const half2*>(&Out1), bias_reg);
frag_b[0] = __hmul2(*reinterpret_cast<const half2*>(&Out2), bias_reg);
frag_b[1] = __hmul2(frag_b[1], bias_reg);
frag_b[0] = __hmul2(frag_b[0], bias_reg);
}
template <>
__device__ inline void dequant<nv_bfloat162, vllm::kFE4M3fn.id()>(
__device__ inline void dequant<nv_bfloat162, vllm::kFE4M3fn.id(), true>(
int q, nv_bfloat162* frag_b) {
// Constants for FP8 (E4M3) and BF16 formats
constexpr int FP8_EXPONENT = 4, FP8_MANTISSA = 3, BF16_EXPONENT = 8;
constexpr int FP8_EXPONENT = 4, BF16_EXPONENT = 8;
constexpr int RIGHT_SHIFT = BF16_EXPONENT - FP8_EXPONENT;
// Calculate MASK for extracting mantissa and exponent
constexpr int MASK1 = 0x80000000;
constexpr int MASK2 = MASK1 >> (FP8_EXPONENT + FP8_MANTISSA);
constexpr int MASK3 = MASK2 & 0x7fffffff;
constexpr int MASK = MASK3 | (MASK3 >> 16);
// Final MASK value: 0x7F007F00
constexpr int MASK = 0x7F007F00;
// Extract and shift FP8 values to BF16 format
int Out1 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT);
int Out2 = ((q << 8) & 0x80008000) | (((q << 8) & MASK) >> RIGHT_SHIFT);
q <<= 8;
int Out2 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT);
// Note: reverse indexing is intentional because weights are permuted
frag_b[1] = *reinterpret_cast<const nv_bfloat162*>(&Out1);
frag_b[0] = *reinterpret_cast<const nv_bfloat162*>(&Out2);
}
template <>
__device__ inline void dequant<nv_bfloat162, vllm::kFE4M3fn.id(), false>(
int q, nv_bfloat162* frag_b) {
dequant<nv_bfloat162, vllm::kFE4M3fn.id(), true>(q, frag_b);
// Constants for FP8 (E4M3) and BF16 formats
constexpr int FP8_EXPONENT = 4, BF16_EXPONENT = 8;
// Construct and apply exponent bias
constexpr int BIAS_OFFSET =
@ -281,9 +390,116 @@ __device__ inline void dequant<nv_bfloat162, vllm::kFE4M3fn.id()>(
__float2bfloat162_rn(*reinterpret_cast<const float*>(&BIAS));
// Convert to bfloat162 and apply bias
frag_b[1] = __hmul2(frag_b[1], bias_reg);
frag_b[0] = __hmul2(frag_b[0], bias_reg);
}
template <>
__device__ inline void dequant<half2, vllm::kFE2M1f.id(), true>(int q,
half2* frag_b) {
// Constants for FP4 (E2M1) and FP16 formats
constexpr int FP4_EXPONENT = 2, FP16_EXPONENT = 5;
constexpr int RIGHT_SHIFT = FP16_EXPONENT - FP4_EXPONENT;
constexpr int MASK = 0x70007000;
// Extract and shift FP4 values to FP16 format
int Out1 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT);
q <<= 4;
int Out2 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT);
// Note: reverse indexing is intentional because weights are permuted
frag_b[1] = __hmul2(*reinterpret_cast<const nv_bfloat162*>(&Out1), bias_reg);
frag_b[0] = __hmul2(*reinterpret_cast<const nv_bfloat162*>(&Out2), bias_reg);
frag_b[1] = *reinterpret_cast<const half2*>(&Out1);
frag_b[0] = *reinterpret_cast<const half2*>(&Out2);
}
template <>
__device__ inline void dequant<half2, vllm::kFE2M1f.id(), false>(
int q, half2* frag_b) {
dequant<half2, vllm::kFE2M1f.id(), true>(q, frag_b);
// Constants for FP4 (E2M1) and FP16 formats
constexpr int FP4_EXPONENT = 2, FP16_EXPONENT = 5;
// Construct and apply exponent bias
constexpr int BIAS_OFFSET =
(1 << (FP16_EXPONENT - 1)) - (1 << (FP4_EXPONENT - 1));
const half2 bias_reg = __float2half2_rn(float(1 << BIAS_OFFSET));
// Convert to half2 and apply bias
frag_b[1] = __hmul2(frag_b[1], bias_reg);
frag_b[0] = __hmul2(frag_b[0], bias_reg);
}
template <>
__device__ inline void dequant<nv_bfloat162, vllm::kFE2M1f.id(), true>(
int q, nv_bfloat162* frag_b) {
// Constants for FP4 (E2M1) and BF16 formats
constexpr int FP4_EXPONENT = 2, BF16_EXPONENT = 8;
constexpr int RIGHT_SHIFT = BF16_EXPONENT - FP4_EXPONENT;
constexpr int MASK = 0x70007000;
// Extract and shift FP4 values to BF16 format
int Out1 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT);
q <<= 4;
int Out2 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT);
// Note: reverse indexing is intentional because weights are permuted
frag_b[1] = *reinterpret_cast<const nv_bfloat162*>(&Out1);
frag_b[0] = *reinterpret_cast<const nv_bfloat162*>(&Out2);
}
template <>
__device__ inline void dequant<nv_bfloat162, vllm::kFE2M1f.id(), false>(
int q, nv_bfloat162* frag_b) {
dequant<nv_bfloat162, vllm::kFE2M1f.id(), true>(q, frag_b);
// Constants for FP4 (E2M1) and BF16 formats
constexpr int FP4_EXPONENT = 2, BF16_EXPONENT = 8;
// Construct and apply exponent bias
constexpr int BIAS_OFFSET =
(1 << (BF16_EXPONENT - 1)) - (1 << (FP4_EXPONENT - 1));
// Add 127 (float exponent bias) to BIAS_OFFSET and shift to float exponent
// position
constexpr uint32_t BIAS = (BIAS_OFFSET + 127) << 23;
const nv_bfloat162 bias_reg =
__float2bfloat162_rn(*reinterpret_cast<const float*>(&BIAS));
// Convert to bfloat162 and apply bias
frag_b[1] = __hmul2(frag_b[1], bias_reg);
frag_b[0] = __hmul2(frag_b[0], bias_reg);
}
template <typename scalar_t2>
__device__ inline void dequant_fp8_scales(int q, scalar_t2* frag_b);
template <>
__device__ inline void dequant_fp8_scales<half2>(int q, half2* frag_b) {
int Out1 = (q & 0xFF00FF00) >> 1;
;
q <<= 8;
int Out2 = (q & 0xFF00FF00) >> 1;
// Note: reverse indexing is intentional because weights are permuted
frag_b[1] = *reinterpret_cast<const half2*>(&Out1);
frag_b[0] = *reinterpret_cast<const half2*>(&Out2);
};
template <>
__device__ inline void dequant_fp8_scales<nv_bfloat162>(int q,
nv_bfloat162* frag_b) {
constexpr int FP8_EXPONENT = 4, BF16_EXPONENT = 8;
constexpr int RIGHT_SHIFT = BF16_EXPONENT - FP8_EXPONENT;
constexpr int MASK = 0x7F007F00;
// Extract and shift FP8 values to BF16 format
int Out1 = ((q & 0x80008000) >> 1) | ((q & MASK) >> RIGHT_SHIFT);
q <<= 8;
int Out2 = ((q & 0x80008000) >> 1) | ((q & MASK) >> RIGHT_SHIFT);
// Note: reverse indexing is intentional because weights are permuted
frag_b[1] = *reinterpret_cast<const nv_bfloat162*>(&Out1);
frag_b[0] = *reinterpret_cast<const nv_bfloat162*>(&Out2);
}
#endif

View File

@ -31,7 +31,10 @@ TEMPLATE = ("template __global__ void Marlin<"
# int8 with zero point case (vllm::kU8) is also supported,
# we don't add it to reduce wheel size.
SCALAR_TYPES = ["vllm::kU4", "vllm::kU4B8", "vllm::kU8B128", "vllm::kFE4M3fn"]
SCALAR_TYPES = [
"vllm::kU4", "vllm::kU4B8", "vllm::kU8B128", "vllm::kFE4M3fn",
"vllm::kFE2M1f"
]
THREAD_CONFIGS = [(128, 128, 256), (64, 256, 256), (64, 128, 128),
(128, 64, 128)]
@ -40,7 +43,7 @@ THREAD_M_BLOCKS = [0.5, 1, 2, 3, 4]
# = 0 : act order case
# = -1 : channelwise quantization
# > 0 : group_size=16*group_blocks
GROUP_BLOCKS = [0, -1, 2, 4, 8]
GROUP_BLOCKS = [0, 1, -1, 2, 4, 8]
DTYPES = ["fp16", "bf16"]
@ -73,6 +76,12 @@ def generate_new_kernels():
# for fp8
if scalar_type == "vllm::kFE4M3fn" and group_blocks not in [-1, 8]:
continue
# nvfp4 only supports group_size == 16
if scalar_type == "vllm::kFE2M1f" and group_blocks != 1:
continue
# other quantization methods don't support group_size = 16
if scalar_type != "vllm::kFE2M1f" and group_blocks == 1:
continue
k_blocks = thread_configs[0] // 16
n_blocks = thread_configs[1] // 16

View File

@ -258,6 +258,7 @@ bool is_valid_config(thread_config_t const& th_config, int thread_m_blocks,
// BIGGROUP: cases for big group size (group_blocks in [-1, 8])
// FZP: cases for float-zero-point (is_zp_float = true)
// ACT: cases for act order case (group_blocks == 0)
// FP4: cases for nvfp4(e2m1) (group_blocks == 1)
#define COMMON_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, -1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 2, NUM_THREADS, false) \
@ -314,6 +315,23 @@ bool is_valid_config(thread_config_t const& th_config, int thread_m_blocks,
BIGGROUP_GET_IF_M234(W_TYPE, 8, 4, 128) \
BIGGROUP_GET_IF_M234(W_TYPE, 4, 8, 128)
#define FP4_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false)
#define FP4_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false)
#define FP4_GET_IF(W_TYPE) \
FP4_GET_IF_M1(W_TYPE, 8, 8, 256) \
FP4_GET_IF_M1(W_TYPE, 8, 4, 128) \
FP4_GET_IF_M1(W_TYPE, 4, 8, 128) \
FP4_GET_IF_M234(W_TYPE, 16, 4, 256) \
FP4_GET_IF_M234(W_TYPE, 8, 4, 128) \
FP4_GET_IF_M234(W_TYPE, 4, 8, 128)
// We currently have 4-bit models only with group_blocks == 4
#define FZP_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 4, NUM_THREADS, true) \
@ -366,6 +384,8 @@ MarlinFuncPtr get_marlin_kernel(const vllm::ScalarType q_type,
COMMON_GET_IF(vllm::kU4B8)
COMMON_GET_IF(vllm::kU8B128)
FP4_GET_IF(vllm::kFE2M1f)
BIGGROUP_GET_IF(vllm::kFE4M3fn)
ACT_GET_IF(vllm::kU4B8)
@ -434,8 +454,8 @@ exec_config_t determine_exec_config(const vllm::ScalarType& q_type, int prob_m,
template <typename scalar_t>
void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
void* zp, void* g_idx, void* perm, void* a_tmp, int prob_m,
int prob_n, int prob_k, int lda, void* workspace,
void* s2, void* zp, void* g_idx, void* perm, void* a_tmp,
int prob_m, int prob_n, int prob_k, int lda, void* workspace,
vllm::ScalarType const& q_type, bool has_act_order,
bool is_k_full, bool has_zp, int num_groups, int group_size,
int dev, cudaStream_t stream, int thread_k_init,
@ -446,11 +466,12 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
q_type == vllm::kU4 || q_type == vllm::kU8,
"q_type must be u4 or u8 when has_zp = True. Got = ", q_type.str());
} else {
TORCH_CHECK(q_type == vllm::kU4B8 || q_type == vllm::kU8B128 ||
q_type == vllm::kFE4M3fn,
"q_type must be uint4b8, uint8b128 or float8_e4m3fn when "
"has_zp = False. Got = ",
q_type.str());
TORCH_CHECK(
q_type == vllm::kU4B8 || q_type == vllm::kU8B128 ||
q_type == vllm::kFE4M3fn || q_type == vllm::kFE2M1f,
"q_type must be uint4b8, uint8b128, float8_e4m3fn or float4_e2m1f when "
"has_zp = False. Got = ",
q_type.str());
}
TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m,
@ -483,6 +504,7 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
int4* C_ptr = (int4*)C;
int4* C_tmp_ptr = (int4*)C_tmp;
const int4* s_ptr = (const int4*)s;
const uint16_t* s2_ptr = (const uint16_t*)s2;
const int4* zp_ptr = (const int4*)zp;
const int* g_idx_ptr = (const int*)g_idx;
const int* perm_ptr = (const int*)perm;
@ -601,7 +623,7 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
// avoid ">>>" being formatted to "> > >"
// clang-format off
kernel<<<blocks, num_threads, max_shared_mem_new, stream>>>(
A_ptr, B_ptr, C_ptr, C_tmp_ptr, s_ptr, zp_ptr, g_idx_ptr, num_groups,
A_ptr, B_ptr, C_ptr, C_tmp_ptr, s_ptr, s2_ptr, zp_ptr, g_idx_ptr, num_groups,
prob_m_split, prob_n, prob_k, lda, locks, part_use_atomic_add,
use_fp32_reduce, max_shared_mem_new);
// clang-format on
@ -617,6 +639,7 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
torch::Tensor gptq_marlin_gemm(
torch::Tensor& a, std::optional<torch::Tensor> c_or_none,
torch::Tensor& b_q_weight, torch::Tensor& b_scales,
std::optional<torch::Tensor> const& global_scale_or_none,
std::optional<torch::Tensor> const& b_zeros_or_none,
std::optional<torch::Tensor> const& g_idx_or_none,
std::optional<torch::Tensor> const& perm_or_none, torch::Tensor& workspace,
@ -759,6 +782,17 @@ torch::Tensor gptq_marlin_gemm(
}
}
torch::Tensor global_scale;
if (global_scale_or_none.has_value()) {
global_scale = global_scale_or_none.value();
TORCH_CHECK(b_q_type == vllm::kFE2M1f,
"global_scale can only be used for float4_e2m1f.");
} else {
global_scale = torch::empty({0}, options);
TORCH_CHECK(!(b_q_type == vllm::kFE2M1f),
"the global_scale parameter must be passed for float4_e2m1f.");
}
torch::Tensor b_zeros;
if (b_zeros_or_none.has_value()) {
b_zeros = b_zeros_or_none.value();
@ -774,8 +808,9 @@ torch::Tensor gptq_marlin_gemm(
"b_q_type must be u4 or u8 when has_zp = True. Got = ", b_q_type.str());
} else {
TORCH_CHECK(b_q_type == vllm::kU4B8 || b_q_type == vllm::kU8B128 ||
b_q_type == vllm::kFE4M3fn,
"b_q_type must be uint4b8, uint8b128 or float8_e4m3fn when "
b_q_type == vllm::kFE4M3fn || b_q_type == vllm::kFE2M1f,
"b_q_type must be uint4b8, uint8b128, float8_e4m3fn or "
"float4_e2m1f when "
"has_zp = False. Got = ",
b_q_type.str());
}
@ -820,22 +855,36 @@ torch::Tensor gptq_marlin_gemm(
int dev = a.get_device();
if (a.scalar_type() == at::ScalarType::Half) {
void* scales_ptr;
if (b_q_type == vllm::kFE2M1f) {
scales_ptr = b_scales.data_ptr<at::Float8_e4m3fn>();
} else {
scales_ptr = b_scales.data_ptr<at::Half>();
}
marlin::marlin_mm<half>(
a.data_ptr<at::Half>(), b_q_weight.data_ptr(), c.data_ptr<at::Half>(),
c_tmp.data_ptr<float>(), b_scales.data_ptr<at::Half>(),
c_tmp.data_ptr<float>(), scales_ptr, global_scale.data_ptr<at::Half>(),
b_zeros.data_ptr(), g_idx.data_ptr(), perm.data_ptr(),
a_tmp.data_ptr<at::Half>(), size_m, size_n, size_k, a.stride(0),
workspace.data_ptr(), b_q_type, has_act_order, is_k_full, has_zp,
num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev),
thread_k, thread_n, sms, use_atomic_add, use_fp32_reduce, is_zp_float);
} else if (a.scalar_type() == at::ScalarType::BFloat16) {
void* scales_ptr;
if (b_q_type == vllm::kFE2M1f) {
scales_ptr = b_scales.data_ptr<at::Float8_e4m3fn>();
} else {
scales_ptr = b_scales.data_ptr<at::BFloat16>();
}
marlin::marlin_mm<nv_bfloat16>(
a.data_ptr<at::BFloat16>(), b_q_weight.data_ptr(),
c.data_ptr<at::BFloat16>(), c_tmp.data_ptr<float>(),
b_scales.data_ptr<at::BFloat16>(), b_zeros.data_ptr(), g_idx.data_ptr(),
perm.data_ptr(), a_tmp.data_ptr<at::BFloat16>(), size_m, size_n, size_k,
a.stride(0), workspace.data_ptr(), b_q_type, has_act_order, is_k_full,
has_zp, num_groups, group_size, dev,
c.data_ptr<at::BFloat16>(), c_tmp.data_ptr<float>(), scales_ptr,
global_scale.data_ptr<at::BFloat16>(), b_zeros.data_ptr(),
g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr<at::BFloat16>(),
size_m, size_n, size_k, a.stride(0), workspace.data_ptr(), b_q_type,
has_act_order, is_k_full, has_zp, num_groups, group_size, dev,
at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms,
use_atomic_add, use_fp32_reduce, is_zp_float);
} else {

View File

@ -7,13 +7,14 @@
#include "marlin_dtypes.cuh"
#include "core/scalar_type.hpp"
#define MARLIN_KERNEL_PARAMS \
const int4 *__restrict__ A, const int4 *__restrict__ B, \
int4 *__restrict__ C, int4 *__restrict__ C_tmp, \
const int4 *__restrict__ scales_ptr, const int4 *__restrict__ zp_ptr, \
const int *__restrict__ g_idx, int num_groups, int prob_m, int prob_n, \
int prob_k, int lda, int *locks, bool use_atomic_add, \
bool use_fp32_reduce, int max_shared_mem
#define MARLIN_KERNEL_PARAMS \
const int4 *__restrict__ A, const int4 *__restrict__ B, \
int4 *__restrict__ C, int4 *__restrict__ C_tmp, \
const int4 *__restrict__ scales_ptr, \
const uint16_t *__restrict__ scale2_ptr, \
const int4 *__restrict__ zp_ptr, const int *__restrict__ g_idx, \
int num_groups, int prob_m, int prob_n, int prob_k, int lda, int *locks, \
bool use_atomic_add, bool use_fp32_reduce, int max_shared_mem
namespace MARLIN_NAMESPACE_NAME {
template <typename scalar_t, // compute dtype, half or nv_float16

View File

@ -292,9 +292,11 @@ __global__ void Marlin(
int4* __restrict__ C_tmp, // fp32 tmp output buffer (for reduce)
const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape
// (k/groupsize)xn
const int4* __restrict__ zp_ptr, // 4bit packed zero-points of shape
// (k/groupsize)x(n/pack_factor)
const int* __restrict__ g_idx, // int32 group indices of shape k
const uint16_t* __restrict__ scale2_ptr, // fp16 global scale (for nvfp4
// only)
const int4* __restrict__ zp_ptr, // 4bit packed zero-points of shape
// (k/groupsize)x(n/pack_factor)
const int* __restrict__ g_idx, // int32 group indices of shape k
int num_groups, // number of scale groups per output channel
int prob_m, // batch dimension m
int prob_n, // output dimension n
@ -325,6 +327,21 @@ __global__ void Marlin(
static constexpr auto w_type = vllm::ScalarType::from_id(w_type_id);
constexpr bool has_zp = w_type == vllm::kU4 || w_type == vllm::kU8;
constexpr bool is_int_type = w_type == vllm::kU4 || w_type == vllm::kU8 ||
w_type == vllm::kU4B8 || w_type == vllm::kU8B128;
// see the comments in dequant.h for more details
constexpr bool dequant_skip_flop =
!is_int_type ||
has_zp && !is_zp_float && !std::is_same<scalar_t, nv_bfloat16>::value ||
has_zp && !is_zp_float && !(w_type == vllm::kU8);
scalar_t2 global_scale;
if constexpr (w_type == vllm::kFE2M1f) {
uint16_t val = scale2_ptr[0];
global_scale = Dtype::num2num2(*reinterpret_cast<scalar_t*>(&val));
}
constexpr bool has_act_order = group_blocks == 0;
constexpr int m_block_size = m_block_size_8 ? 8 : (16 * thread_m_blocks);
@ -481,7 +498,7 @@ __global__ void Marlin(
constexpr int s_sh_stride = 16 * thread_n_blocks / 8;
constexpr int s_tb_groups =
!has_act_order && group_blocks != -1 && group_blocks < thread_k_blocks
? thread_k_blocks / group_blocks
? thread_k_blocks / group_blocks / (w_type == vllm::kFE2M1f ? 2 : 1)
: 1;
constexpr int s_sh_stage = s_tb_groups * s_sh_stride;
int s_gl_rd_delta = s_gl_stride;
@ -540,7 +557,8 @@ __global__ void Marlin(
if constexpr (group_blocks == -1) {
s_gl_rd = s_sh_stride * slice_col + threadIdx.x;
} else {
s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) +
s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) /
(w_type == vllm::kFE2M1f ? 2 : 1) +
s_sh_stride * slice_col + threadIdx.x;
}
}
@ -564,10 +582,20 @@ __global__ void Marlin(
// we scale a `half2` tile in column-major layout in the former and in
// row-major in the latter case.
int s_sh_rd;
if constexpr (group_blocks != -1)
if constexpr (group_blocks != -1 && w_type == vllm::kFE2M1f) {
auto warp_id = threadIdx.x / 32;
int n_warps = thread_n_blocks / 4;
int warp_row = warp_id / n_warps;
s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
(threadIdx.x % 32) / 4;
else if constexpr (group_blocks == -1 && (m_block_size_8 || has_zp))
s_sh_rd = s_sh_rd * 2 + warp_row % 2;
} else if constexpr (group_blocks != -1)
s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
(threadIdx.x % 32) / 4;
else if constexpr (group_blocks == -1 &&
(m_block_size_8 || (has_zp && !dequant_skip_flop)))
s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
(threadIdx.x % 32) / 8;
else
@ -681,7 +709,7 @@ __global__ void Marlin(
sh_first_group_id = first_group_id;
sh_num_groups = last_group_id - first_group_id + 1;
if (sh_num_groups < act_s_max_num_groups) {
if (sh_num_groups > act_s_max_num_groups) {
sh_num_groups = act_s_max_num_groups;
}
@ -887,12 +915,19 @@ __global__ void Marlin(
cur_k += k_iter_size * (k % b_sh_wr_iters);
int k_blocks = cur_k / 16;
int cur_group_id = k_blocks / group_blocks;
int cur_group_id =
k_blocks / (group_blocks * (w_type == vllm::kFE2M1f ? 2 : 1));
int4* sh_s_stage = sh_s + s_sh_stage * pipe;
reinterpret_cast<int4*>(&frag_s[k % 2])[0] =
sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride];
if constexpr (w_type_id != vllm::kFE2M1f.id()) {
reinterpret_cast<int4*>(&frag_s[k % 2])[0] =
sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride];
} else {
reinterpret_cast<int2*>(&frag_s[k % 2])[0] =
reinterpret_cast<int2*>(
sh_s_stage)[s_sh_rd + cur_group_id * (2 * s_sh_stride)];
}
}
}
@ -1065,22 +1100,7 @@ __global__ void Marlin(
};
auto dequant_data = [&](int q, scalar_t2* frag_b_ptr) {
if constexpr (has_zp && is_zp_float || !has_zp) {
dequant<scalar_t2, w_type_id>(q, frag_b_ptr);
} else {
static_assert(has_zp && !is_zp_float);
static_assert(w_type_id == vllm::kU4.id() || w_type_id == vllm::kU8.id());
// If (has_zp && !is_zp_float),
// we use not-zp version `dequant` function
// to improve numerical accuracy.
// Since both weight and zero point are dequanted using this logic,
// the final dequanted weight would be correct.
if constexpr (w_type_id == vllm::kU4.id()) {
dequant<scalar_t2, vllm::kU4B8.id()>(q, frag_b_ptr);
} else if constexpr (w_type_id == vllm::kU8.id()) {
dequant<scalar_t2, vllm::kU8B128.id()>(q, frag_b_ptr);
}
}
dequant<scalar_t2, w_type_id, dequant_skip_flop>(q, frag_b_ptr);
};
// Execute the actual tensor core matmul of a sub-tile.
@ -1110,13 +1130,23 @@ __global__ void Marlin(
dequant_data(zp_quant_1, reinterpret_cast<scalar_t2*>(&frag_zp) + 2);
}
}
if constexpr (has_zp && is_zp_float) {
if constexpr (!dequant_skip_flop && has_zp && is_zp_float) {
if (is_new_zp) {
reinterpret_cast<int4*>(&frag_zp)[0] =
reinterpret_cast<int4*>(&frag_zpf[k2])[0];
}
}
if constexpr (w_type == vllm::kFE2M1f) {
int s_quant_0 = reinterpret_cast<int*>(frag_s[k2])[0];
int s_quant_1 = reinterpret_cast<int*>(frag_s[k2])[1];
dequant_fp8_scales<scalar_t2>(s_quant_0,
reinterpret_cast<scalar_t2*>(&frag_s[k2]));
dequant_fp8_scales<scalar_t2>(
s_quant_1, reinterpret_cast<scalar_t2*>(&frag_s[k2]) + 2);
}
// We have the m dimension as the inner loop in order to encourage overlapping
// dequantization and matmul operations.
#pragma unroll
@ -1125,7 +1155,10 @@ __global__ void Marlin(
FragB frag_b1;
int b_quant_0, b_quant_1;
if constexpr (w_type.size_bits() == 4) {
if constexpr (w_type_id == vllm::kFE2M1f.id()) {
b_quant_1 = frag_b_quant[k2][0][j];
b_quant_0 = b_quant_1 << 8;
} else if constexpr (w_type.size_bits() == 4) {
b_quant_0 = frag_b_quant[k2][0][j];
b_quant_1 = b_quant_0 >> 8;
} else {
@ -1138,6 +1171,11 @@ __global__ void Marlin(
dequant_data(b_quant_0, reinterpret_cast<scalar_t2*>(&frag_b0));
dequant_data(b_quant_1, reinterpret_cast<scalar_t2*>(&frag_b1));
if constexpr (dequant_skip_flop && has_zp && !is_zp_float) {
sub_zp<scalar_t>(frag_b0, frag_zp[j], 0);
sub_zp<scalar_t>(frag_b1, frag_zp[j], 1);
}
// Apply scale to frag_b0
if constexpr (has_act_order) {
static_assert(group_blocks != -1);
@ -1145,7 +1183,8 @@ __global__ void Marlin(
act_frag_s[k2][2][j], act_frag_s[k2][3][j], 0);
scale4<scalar_t>(frag_b1, act_frag_s[k2][0][j], act_frag_s[k2][1][j],
act_frag_s[k2][2][j], act_frag_s[k2][3][j], 1);
} else if constexpr (has_zp && !is_zp_float && group_blocks == -1) {
} else if constexpr (!dequant_skip_flop && has_zp && !is_zp_float &&
group_blocks == -1) {
int idx = (threadIdx.x / 4) % 2;
scalar_t2 s2 = Dtype::nums2num2(
reinterpret_cast<scalar_t*>(&frag_s[j / 2][j % 2 * 2 + 0])[idx],
@ -1153,7 +1192,7 @@ __global__ void Marlin(
if (is_new_zp) frag_zp[j] = __hmul2(frag_zp[j], s2);
scale_and_sub<scalar_t>(frag_b0, s2.x, frag_zp[j].x);
scale_and_sub<scalar_t>(frag_b1, s2.y, frag_zp[j].y);
} else if constexpr (has_zp && group_blocks != -1) {
} else if constexpr (!dequant_skip_flop && has_zp && group_blocks != -1) {
if (is_new_zp)
frag_zp[j] = __hmul2(frag_zp[j],
*reinterpret_cast<scalar_t2*>(&frag_s[k2][j]));
@ -1408,10 +1447,15 @@ __global__ void Marlin(
// For per-column quantization we finally apply the scale here (only for
// 4-bit)
if constexpr (!has_act_order && group_blocks == -1 &&
w_type.size_bits() == 4 && !has_zp) {
w_type.size_bits() == 4 &&
(has_zp && dequant_skip_flop || !has_zp)) {
res = __hmul2(res, s[0]);
}
if constexpr (w_type == vllm::kFE2M1f) {
res = __hmul2(res, global_scale);
}
if constexpr (m_block_size_8) {
((scalar_t*)sh_red)[idx] = res.x;
((scalar_t*)sh_red)[idx + 8 * c_sh_stride] = res.y;
@ -1488,7 +1532,9 @@ __global__ void Marlin(
if constexpr (has_zp && !is_zp_float && group_blocks == -1) {
if (i == 0) {
fetch_col_zp_to_shared();
fetch_col_scale_to_shared();
if constexpr (!dequant_skip_flop) {
fetch_col_scale_to_shared();
}
}
}
fetch_to_shared(i, i, i < slice_iters);
@ -1563,7 +1609,8 @@ __global__ void Marlin(
bool last = slice_idx == slice_count - 1;
// For per-column scales, we only fetch them here in the final step before
// write-out
if constexpr (!has_act_order && group_blocks == -1 && !has_zp) {
if constexpr (!has_act_order && group_blocks == -1 &&
(has_zp && dequant_skip_flop || !has_zp)) {
if (w_type.size_bits() == 8 || (last || use_atomic_add)) {
if (s_sh_wr_pred) {
cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]);
@ -1573,7 +1620,8 @@ __global__ void Marlin(
}
thread_block_reduce();
if constexpr (!has_act_order && group_blocks == -1 && !has_zp) {
if constexpr (!has_act_order && group_blocks == -1 &&
(has_zp && dequant_skip_flop || !has_zp)) {
if (w_type.size_bits() == 8 || (last || use_atomic_add)) {
cp_async_wait<0>();
__syncthreads();
@ -1597,7 +1645,8 @@ __global__ void Marlin(
// that converts the fp32 results to fp16 (so that we avoid possible
// overflow in fp16)
if constexpr (!has_act_order && group_blocks == -1 &&
w_type.size_bits() == 8 && !has_zp) {
w_type.size_bits() == 8 &&
(has_zp && dequant_skip_flop || !has_zp)) {
if (threadIdx.x / 32 < thread_n_blocks / 4) {
#pragma unroll
for (int i = 0; i < thread_m_blocks; i++) {

View File

@ -292,8 +292,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// gptq_marlin Optimized Quantized GEMM for GPTQ.
ops.def(
"gptq_marlin_gemm(Tensor a, Tensor? c_or_none, Tensor b_q_weight, "
"Tensor b_scales, Tensor? b_zeros_or_none, Tensor? g_idx_or_none, "
"Tensor? perm_or_none, Tensor workspace, int b_q_type, "
"Tensor b_scales, Tensor? global_scale, Tensor? b_zeros_or_none, Tensor? "
"g_idx_or_none, Tensor? perm_or_none, Tensor workspace, int b_q_type, "
"SymInt size_m, SymInt size_n, SymInt size_k, bool is_k_full, "
"bool use_atomic_add, bool use_fp32_reduce, bool is_zp_float) -> Tensor",
{stride_tag});

View File

@ -16,6 +16,8 @@ from vllm.model_executor.layers.fused_moe import fused_moe
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
from vllm.model_executor.layers.fused_moe.moe_torch_iterative import (
fused_moe as iterative_moe)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
rand_marlin_weight_fp4_like)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
marlin_quant_fp8_torch)
from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
@ -286,21 +288,64 @@ def test_mixtral_moe(dtype: torch.dtype, padding: bool, use_rocm_aiter: bool,
atol=mixtral_moe_tol[dtype])
def marlin_moe_generate_valid_test_cases():
import itertools
m_list = [1, 123, 666]
n_list = [128, 1024]
k_list = [256, 2048]
e_list = [4, 12]
topk_list = [2, 3]
ep_size_list = [1, 4]
dtype_list = [torch.half, torch.bfloat16]
group_size_list = [-1, 16, 32, 128]
act_order_list = [True, False]
quant_type_list = [
scalar_types.float4_e2m1f,
scalar_types.float8_e4m3fn,
scalar_types.uint4,
scalar_types.uint4b8,
scalar_types.uint8b128,
]
is_k_full_list = [True, False]
all_combinations = itertools.product(m_list, n_list, k_list, e_list,
topk_list, ep_size_list, dtype_list,
group_size_list, act_order_list,
quant_type_list, is_k_full_list)
def is_invalid(m, n, k, e, topk, ep_size, dtype, group_size, act_order,
quant_type, is_k_full):
if quant_type == scalar_types.float8_e4m3fn and \
group_size not in [-1, 128]:
return False
if quant_type == scalar_types.float4_e2m1f and group_size != 16:
return False
if quant_type != scalar_types.float4_e2m1f and group_size == 16:
return False
# Filter act_order
if act_order:
if group_size in (-1, k, n):
return False
if quant_type not in [scalar_types.uint4b8]:
return False
elif not is_k_full:
return False
return True
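# Note: despite its name, is_invalid returns True for combinations that
# are kept and False for combinations that are filtered out.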
cases = []
for case in all_combinations:
if is_invalid(*case):
cases.append(case)
return cases
@pytest.mark.flaky(reruns=2)
@pytest.mark.parametrize("m", [1, 123, 666])
@pytest.mark.parametrize("n", [128, 1024])
@pytest.mark.parametrize("k", [256, 2048])
@pytest.mark.parametrize("e", [4, 12])
@pytest.mark.parametrize("topk", [2, 3])
@pytest.mark.parametrize("ep_size", [1, 4])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("group_size", [-1, 32, 128])
@pytest.mark.parametrize("act_order", [True, False])
@pytest.mark.parametrize("quant_type", [
scalar_types.uint4, scalar_types.uint8b128, scalar_types.uint4b8,
scalar_types.float8_e4m3fn
])
@pytest.mark.parametrize("is_k_full", [True, False])
@pytest.mark.parametrize(("m, n, k, e, topk, ep_size, dtype, group_size,"
"act_order, quant_type, is_k_full"),
marlin_moe_generate_valid_test_cases())
@pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm")
def test_fused_marlin_moe(
m: int,
@ -338,6 +383,11 @@ def test_fused_marlin_moe(
if not is_k_full:
return
if quant_type == scalar_types.float4_e2m1f and group_size != 16:
return
if quant_type != scalar_types.float4_e2m1f and group_size == 16:
return
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 20
w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 20
@ -355,12 +405,27 @@ def test_fused_marlin_moe(
w_ref1_l = []
qweight1_l = []
scales1_l = []
global_scale1_l = []
zeros1_l = []
g_idx1_l = []
sort_indices1_l = []
for i in range(w1.shape[0]):
if has_zp:
if quant_type == scalar_types.float4_e2m1f:
w_ref1, qweight1, scales1, global_scale1 = \
rand_marlin_weight_fp4_like(w1[i], group_size)
w_ref1_l.append(w_ref1.T)
qweight1_l.append(qweight1)
scales1_l.append(scales1)
global_scale1_l.append(global_scale1)
elif quant_type == scalar_types.float8_e4m3fn:
w_ref1, qweight1, scales1 = marlin_quant_fp8_torch(
w1[i], group_size)
w_ref1_l.append(w_ref1.T)
qweight1_l.append(qweight1)
scales1_l.append(scales1)
elif has_zp:
w_ref1, qweight1, scales1, zeros1 = awq_marlin_quantize(
w1[i].transpose(1, 0), quant_type, group_size)
@ -368,7 +433,7 @@ def test_fused_marlin_moe(
qweight1_l.append(qweight1)
scales1_l.append(scales1)
zeros1_l.append(zeros1)
elif quant_type != scalar_types.float8_e4m3fn:
else:
test_perm = torch.randperm(k)
w_ref1, qweight1, scales1, g_idx1, sort_indices1, _ = \
marlin_quantize(w1[i].transpose(1, 0), quant_type,
@ -379,16 +444,11 @@ def test_fused_marlin_moe(
scales1_l.append(scales1)
g_idx1_l.append(g_idx1)
sort_indices1_l.append(sort_indices1)
else:
w_ref1, qweight1, scales1 = marlin_quant_fp8_torch(
w1[i], group_size)
w_ref1_l.append(w_ref1.T)
qweight1_l.append(qweight1)
scales1_l.append(scales1)
w_ref1 = stack_and_dev(w_ref1_l)
qweight1 = stack_and_dev(qweight1_l).contiguous()
scales1 = stack_and_dev(scales1_l)
global_scale1 = stack_and_dev(global_scale1_l) if global_scale1_l else None
g_idx1 = stack_and_dev(g_idx1_l) if g_idx1_l else None
zeros1 = stack_and_dev(zeros1_l) if zeros1_l else None
sort_indices1 = stack_and_dev(sort_indices1_l) if sort_indices1_l else None
@ -396,12 +456,27 @@ def test_fused_marlin_moe(
w_ref2_l = []
qweight2_l = []
scales2_l = []
global_scale2_l = []
zeros2_l = []
g_idx2_l = []
sort_indices2_l = []
for i in range(w2.shape[0]):
if has_zp:
if quant_type == scalar_types.float4_e2m1f:
w_ref2, qweight2, scales2, global_scale2 = \
rand_marlin_weight_fp4_like(w2[i], group_size)
w_ref2_l.append(w_ref2.T)
qweight2_l.append(qweight2)
scales2_l.append(scales2)
global_scale2_l.append(global_scale2)
elif quant_type == scalar_types.float8_e4m3fn:
w_ref2, qweight2, scales2 = marlin_quant_fp8_torch(
w2[i], group_size)
w_ref2_l.append(w_ref2.T)
qweight2_l.append(qweight2)
scales2_l.append(scales2)
elif has_zp:
w_ref2, qweight2, scales2, zeros2 = awq_marlin_quantize(
w2[i].transpose(1, 0), quant_type, group_size)
@ -409,7 +484,7 @@ def test_fused_marlin_moe(
qweight2_l.append(qweight2)
scales2_l.append(scales2)
zeros2_l.append(zeros2)
elif quant_type != scalar_types.float8_e4m3fn:
else:
test_perm = torch.randperm(n)
w_ref2, qweight2, scales2, g_idx2, sort_indices2, _ = \
marlin_quantize(w2[i].transpose(1, 0), quant_type,
@ -420,24 +495,18 @@ def test_fused_marlin_moe(
scales2_l.append(scales2)
g_idx2_l.append(g_idx2)
sort_indices2_l.append(sort_indices2)
else:
w_ref2, qweight2, scales2 = marlin_quant_fp8_torch(
w2[i], group_size)
w_ref2_l.append(w_ref2.T)
qweight2_l.append(qweight2)
scales2_l.append(scales2)
w_ref2 = stack_and_dev(w_ref2_l)
qweight2 = stack_and_dev(qweight2_l).contiguous()
scales2 = stack_and_dev(scales2_l)
global_scale2 = stack_and_dev(global_scale2_l) if global_scale2_l else None
g_idx2 = stack_and_dev(g_idx2_l) if g_idx2_l else None
zeros2 = stack_and_dev(zeros2_l) if zeros2_l else None
sort_indices2 = stack_and_dev(sort_indices2_l) if sort_indices2_l else None
score = torch.randn((m, e), device="cuda", dtype=dtype)
topk_weights, topk_ids, token_expert_indices = fused_topk(
a, score, topk, False)
topk_weights, topk_ids, _ = fused_topk(a, score, topk, False)
torch_output = torch_moe(a, w_ref1, w_ref2, score, topk, e_map)
@ -452,6 +521,8 @@ def test_fused_marlin_moe(
topk_ids,
global_num_experts=e,
expert_map=e_map,
global_scale1=global_scale1,
global_scale2=global_scale2,
g_idx1=g_idx1,
g_idx2=g_idx2,
sort_indices1=sort_indices1,

View File

@ -20,6 +20,8 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import (
MARLIN_SUPPORTED_GROUP_SIZES, marlin_make_empty_g_idx,
marlin_make_workspace_new, marlin_permute_scales,
query_marlin_supported_quant_types)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
FP4_MARLIN_SUPPORTED_GROUP_SIZES, rand_marlin_weight_fp4_like)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
marlin_quant_fp8_torch)
from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
@ -190,9 +192,10 @@ def test_awq_marlin_repack(k_chunk, n_chunk, quant_type, group_size,
reason="Marlin is not supported on this GPU type.")
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
@pytest.mark.parametrize("quant_type",
query_marlin_supported_quant_types(False))
@pytest.mark.parametrize("group_size", MARLIN_SUPPORTED_GROUP_SIZES)
@pytest.mark.parametrize("quant_type", query_marlin_supported_quant_types())
@pytest.mark.parametrize(
"group_size",
set(MARLIN_SUPPORTED_GROUP_SIZES + FP4_MARLIN_SUPPORTED_GROUP_SIZES))
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
@pytest.mark.parametrize("act_order", ACT_ORDER_OPTS)
@pytest.mark.parametrize("is_k_full", K_FULL_OPTS)
@ -210,6 +213,7 @@ def test_gptq_marlin_gemm(
use_fp32_reduce,
):
m_factor, n_factor, k_factor = mnk_factors
has_zp = quant_type in [scalar_types.uint4, scalar_types.uint8]
size_m = m_factor
size_k = k_chunk * k_factor
@ -220,6 +224,8 @@ def test_gptq_marlin_gemm(
return
if group_size == size_k:
return
if has_zp:
return
if size_k % group_size != 0:
return
@ -227,7 +233,15 @@ def test_gptq_marlin_gemm(
a_input = rand_data((size_m, size_k))
b_weight = rand_data((size_k, size_n))
if quant_type == scalar_types.float8_e4m3fn:
if quant_type == scalar_types.float4_e2m1f:
if group_size != 16 or act_order:
return
w_ref, marlin_q_w, marlin_s, marlin_s2 = rand_marlin_weight_fp4_like(
b_weight.T, group_size)
g_idx = None
sort_indices = None
marlin_zp = None
elif quant_type == scalar_types.float8_e4m3fn:
if group_size not in [-1, 128]:
return
if act_order:
@ -236,26 +250,39 @@ def test_gptq_marlin_gemm(
b_weight.T, group_size)
g_idx = None
sort_indices = None
marlin_zp = None
marlin_s2 = None
elif has_zp:
if group_size == 16:
return
w_ref, marlin_q_w, marlin_s, marlin_zp = awq_marlin_quantize(
b_weight, quant_type, group_size)
g_idx = None
sort_indices = None
marlin_s2 = None
else:
if group_size == 16:
return
w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, _ = marlin_quantize(
b_weight, quant_type, group_size, act_order)
marlin_zp = marlin_make_empty_g_idx(marlin_s.device)
marlin_zp = None
marlin_s2 = None
workspace = marlin_make_workspace_new(w_ref.device)
opcheck(
torch.ops._C.gptq_marlin_gemm,
(a_input, None, marlin_q_w, marlin_s, marlin_zp, g_idx, sort_indices,
workspace, quant_type.id, a_input.shape[0], b_weight.shape[1],
a_input.shape[1], is_k_full, use_atomic_add, use_fp32_reduce, False),
test_utils=DEFAULT_OPCHECK_TEST_UTILS)
opcheck(torch.ops._C.gptq_marlin_gemm,
(a_input, None, marlin_q_w, marlin_s, marlin_s2, marlin_zp, g_idx,
sort_indices, workspace, quant_type.id, a_input.shape[0],
b_weight.shape[1], a_input.shape[1], is_k_full, use_atomic_add,
use_fp32_reduce, False),
test_utils=DEFAULT_OPCHECK_TEST_UTILS)
output = ops.gptq_marlin_gemm(
a_input,
None,
marlin_q_w,
marlin_s,
marlin_s2,
marlin_zp,
g_idx,
sort_indices,
@ -339,67 +366,6 @@ def test_gptq_marlin_24_gemm(k_chunk, n_chunk, quant_type, group_size,
assert max_diff < 0.04
@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"),
reason="Marlin is not supported on this GPU type.")
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
@pytest.mark.parametrize("quant_type",
query_marlin_supported_quant_types(True))
@pytest.mark.parametrize("group_size", MARLIN_SUPPORTED_GROUP_SIZES)
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
@pytest.mark.parametrize("use_fp32_reduce", USE_FP32_REDUCE_OPTS)
def test_awq_marlin_gemm(
k_chunk,
n_chunk,
quant_type,
group_size,
mnk_factors,
use_fp32_reduce,
):
m_factor, n_factor, k_factor = mnk_factors
size_m = m_factor
size_k = k_chunk * k_factor
size_n = n_chunk * n_factor
a_input = rand_data((size_m, size_k))
b_weight = rand_data((size_k, size_n))
w_ref, marlin_q_w, marlin_s, marlin_zp = awq_marlin_quantize(
b_weight, quant_type, group_size)
g_idx = torch.empty(0, dtype=torch.int, device=marlin_q_w.device)
sort_indices = torch.empty(0, dtype=torch.int, device=marlin_q_w.device)
is_k_full = True
workspace = marlin_make_workspace_new(a_input.device)
output = ops.gptq_marlin_gemm(
a_input,
None,
marlin_q_w,
marlin_s,
marlin_zp,
g_idx,
sort_indices,
workspace,
quant_type,
a_input.shape[0],
b_weight.shape[1],
a_input.shape[1],
is_k_full=is_k_full,
use_fp32_reduce=use_fp32_reduce,
is_zp_float=False,
)
output_ref = torch.matmul(a_input, w_ref)
torch.cuda.synchronize()
max_diff = compute_max_diff(output, output_ref)
assert max_diff < 0.04
@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"),
reason="Marlin is not supported on this GPU type.")
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
@ -452,6 +418,7 @@ def test_hqq_marlin_gemm(
None,
marlin_w_q,
marlin_s,
None,
marlin_zp,
g_idx,
g_idx_sort_indices,
@ -564,6 +531,7 @@ def test_marlin_gemm_subset_input():
None,
marlin_q_w,
marlin_s,
None,
marlin_zp,
g_idx,
sort_indices,

View File

@ -333,6 +333,7 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
c: Optional[torch.Tensor],
b_q_weight: torch.Tensor,
b_scales: torch.Tensor,
global_scale: Optional[torch.Tensor],
b_zeros: Optional[torch.Tensor],
g_idx: Optional[torch.Tensor],
perm: Optional[torch.Tensor],
@ -866,6 +867,7 @@ def gptq_marlin_gemm(a: torch.Tensor,
c: Optional[torch.Tensor],
b_q_weight: torch.Tensor,
b_scales: torch.Tensor,
global_scale: Optional[torch.Tensor],
b_zeros: Optional[torch.Tensor],
g_idx: Optional[torch.Tensor],
perm: Optional[torch.Tensor],
@ -878,9 +880,10 @@ def gptq_marlin_gemm(a: torch.Tensor,
use_atomic_add: bool = False,
use_fp32_reduce: bool = False,
is_zp_float: bool = False) -> torch.Tensor:
return torch.ops._C.gptq_marlin_gemm(a, c, b_q_weight, b_scales, b_zeros,
g_idx, perm, workspace, b_q_type.id,
size_m, size_n, size_k, is_k_full,
return torch.ops._C.gptq_marlin_gemm(a, c, b_q_weight, b_scales,
global_scale, b_zeros, g_idx, perm,
workspace, b_q_type.id, size_m,
size_n, size_k, is_k_full,
use_atomic_add, use_fp32_reduce,
is_zp_float)
@ -1381,6 +1384,7 @@ def topk_softmax(topk_weights: torch.Tensor, topk_ids: torch.Tensor,
def moe_wna16_marlin_gemm(input: torch.Tensor, output: Optional[torch.Tensor],
b_qweight: torch.Tensor, b_scales: torch.Tensor,
global_scale: Optional[torch.Tensor],
b_qzeros: Optional[torch.Tensor],
g_idx: Optional[torch.Tensor],
perm: Optional[torch.Tensor],
@ -1395,11 +1399,11 @@ def moe_wna16_marlin_gemm(input: torch.Tensor, output: Optional[torch.Tensor],
use_fp32_reduce: bool,
is_zp_float: bool) -> torch.Tensor:
return torch.ops._moe_C.moe_wna16_marlin_gemm(
input, output, b_qweight, b_scales, b_qzeros, g_idx, perm, workspace,
sorted_token_ids, expert_ids, num_tokens_past_padded, topk_weights,
moe_block_size, top_k, mul_topk_weights, is_ep, b_q_type.id, size_m,
size_n, size_k, is_k_full, use_atomic_add, use_fp32_reduce,
is_zp_float)
input, output, b_qweight, b_scales, global_scale, b_qzeros, g_idx,
perm, workspace, sorted_token_ids, expert_ids, num_tokens_past_padded,
topk_weights, moe_block_size, top_k, mul_topk_weights, is_ep,
b_q_type.id, size_m, size_n, size_k, is_k_full, use_atomic_add,
use_fp32_reduce, is_zp_float)
if supports_moe_ops and hasattr(torch.ops._moe_C, "marlin_gemm_moe"):

View File

@ -25,6 +25,8 @@ def fused_marlin_moe(hidden_states: torch.Tensor,
quant_type_id: int,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
global_scale1: Optional[torch.Tensor] = None,
global_scale2: Optional[torch.Tensor] = None,
g_idx1: Optional[torch.Tensor] = None,
g_idx2: Optional[torch.Tensor] = None,
sort_indices1: Optional[torch.Tensor] = None,
@ -64,11 +66,13 @@ def fused_marlin_moe(hidden_states: torch.Tensor,
quant_type = ScalarType.from_id(quant_type_id)
assert quant_type in [
scalar_types.uint4, scalar_types.uint8b128, scalar_types.uint4b8,
scalar_types.float8_e4m3fn
scalar_types.float8_e4m3fn, scalar_types.float4_e2m1f
]
int4_scalar_types = [scalar_types.uint4, scalar_types.uint4b8]
num_bits = 4 if quant_type in int4_scalar_types else 8
bit4_scalar_types = [
scalar_types.uint4, scalar_types.uint4b8, scalar_types.float4_e2m1f
]
num_bits = 4 if quant_type in bit4_scalar_types else 8
# Check constraints.
assert hidden_states.shape[0] == gating_output.shape[
@ -133,6 +137,7 @@ def fused_marlin_moe(hidden_states: torch.Tensor,
intermediate_cache1,
w1,
w1_scale,
global_scale1,
w1_zeros,
g_idx1,
sort_indices1,
@ -165,6 +170,7 @@ def fused_marlin_moe(hidden_states: torch.Tensor,
intermediate_cache3,
w2,
w2_scale,
global_scale2,
w2_zeros,
g_idx2,
sort_indices2,
@ -202,6 +208,8 @@ def fused_marlin_moe_fake(hidden_states: torch.Tensor,
topk_ids: torch.Tensor,
quant_type_id: int,
global_num_experts: int = -1,
global_scale1: Optional[torch.Tensor] = None,
global_scale2: Optional[torch.Tensor] = None,
expert_map: Optional[torch.Tensor] = None,
g_idx1: Optional[torch.Tensor] = None,
g_idx2: Optional[torch.Tensor] = None,

View File

@ -304,8 +304,10 @@ class HQQMarlinMethod(LinearMethodBase):
marlin_out = ops.gptq_marlin_gemm(
x,
None,
layer.marlin_qweight,
scales,
None,
zeros,
layer.g_idx,
layer.g_idx_sort_indices,
@ -315,7 +317,7 @@ class HQQMarlinMethod(LinearMethodBase):
self.output_size_per_partition,
self.input_size_per_partition,
True, # is_k_full
True, # has_zp
False, # use atomic add
True, # use 32-bit reduce
True, # use float zp
)

View File

@ -17,6 +17,9 @@ from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
apply_fp4_marlin_linear, is_fp4_marlin_supported,
prepare_fp4_layer_for_marlin, prepare_moe_fp4_layer_for_marlin)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
is_layer_skipped)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
@ -24,6 +27,7 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
from vllm.model_executor.parameter import (ModelWeightParameter,
PerTensorScaleParameter)
from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types
logger = init_logger(__name__)
@ -196,7 +200,7 @@ class ModelOptNvFp4Config(QuantizationConfig):
@classmethod
def get_min_capability(cls) -> int:
return 100
return 80
@classmethod
def get_config_filenames(cls) -> List[str]:
@ -278,9 +282,15 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
def __init__(self, quant_config: ModelOptNvFp4Config):
self.quant_config = quant_config
self.cutlass_nvfp4_supported = cutlass_fp4_supported()
self.use_marlin = False
if not self.cutlass_nvfp4_supported:
raise ValueError("Current platform does not support NVFP4"
" quantization. Please use Blackwell and above.")
if is_fp4_marlin_supported():
self.use_marlin = True
else:
raise ValueError("Current platform does not support NVFP4"
" quantization. Please use Blackwell and"
" above.")
def create_weights(
self,
@ -392,12 +402,29 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
layer.weight_scale_swizzled = Parameter(swizzled_weight_scale,
requires_grad=False)
if self.use_marlin:
prepare_fp4_layer_for_marlin(layer)
del layer.alpha
del layer.input_scale
del layer.weight_scale_swizzled
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if self.use_marlin:
return apply_fp4_marlin_linear(
input=x,
weight=layer.weight,
weight_scale=layer.weight_scale,
weight_scale_2=layer.weight_scale_2,
workspace=layer.workspace,
size_n=layer.output_size_per_partition,
size_k=layer.input_size_per_partition,
bias=bias)
output_dtype = x.dtype
# for input only the contracting dimension has a constraint.
@ -434,6 +461,16 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
def __init__(self, quant_config: ModelOptNvFp4Config):
self.quant_config = quant_config
self.cutlass_nvfp4_supported = cutlass_fp4_supported()
self.use_marlin = False
if not self.cutlass_nvfp4_supported:
if is_fp4_marlin_supported():
self.use_marlin = True
else:
raise ValueError("Current platform does not support NVFP4"
" quantization. Please use Blackwell and"
" above.")
def create_weights(self, layer: torch.nn.Module, num_experts: int,
hidden_size: int, intermediate_size_per_partition: int,
@ -442,6 +479,8 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
raise ValueError("NVFP4 quantization was selected, "
" dynamic quantization is not supported.")
layer.num_experts = num_experts
layer.params_dtype = params_dtype
layer.quant_config = self.quant_config
weight_dtype = torch.uint8
weight_scale_dtype = torch.float8_e4m3fn
@ -594,7 +633,15 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
layer.w2_blockscale_swizzled = Parameter(w2_blockscale_swizzled,
requires_grad=False)
return
if self.use_marlin:
prepare_moe_fp4_layer_for_marlin(layer)
del layer.g1_alphas
del layer.g2_alphas
del layer.w13_input_scale_quant
del layer.w2_input_scale_quant
del layer.w13_blockscale_swizzled
del layer.w2_blockscale_swizzled
def apply(
self,
@ -614,6 +661,35 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
apply_router_weight_on_input: bool = False,
activation: str = "silu",
):
if self.use_marlin:
topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias,
)
return torch.ops.vllm.fused_marlin_moe(
x,
layer.w13_weight,
layer.w2_weight,
layer.w13_weight_scale,
layer.w2_weight_scale,
router_logits,
topk_weights,
topk_ids,
global_scale1=layer.w13_weight_scale_2,
global_scale2=layer.w2_weight_scale_2,
quant_type_id=scalar_types.float4_e2m1f.id,
global_num_experts=global_num_experts,
expert_map=expert_map)
assert activation == "silu", "Only SiLU activation is supported."
assert not apply_router_weight_on_input, (
"Router weight on input is not "

View File

@ -33,7 +33,7 @@ USE_FP32_REDUCE_DEFAULT = True
# without runtime zero-point. We support common cases, i.e. AWQ and GPTQ.
# TODO: we may want to move this into the C++ so it's closer to the actual impl
def query_marlin_supported_quant_types(
has_zp: bool,
has_zp: Optional[bool] = None,
include_fp_type: bool = True,
device_capability: Optional[int] = None,
):
@ -45,6 +45,16 @@ def query_marlin_supported_quant_types(
if device_capability < 80:
return []
# - has_zp is True: return quant_types that have zero points
# - has_zp is False: return quant_types that do not have zero points
# - has_zp is None: return both
if has_zp is None:
types0 = query_marlin_supported_quant_types(False, include_fp_type,
device_capability)
types1 = query_marlin_supported_quant_types(True, include_fp_type,
device_capability)
return types0 + types1
if has_zp:
# AWQ style, unsigned + runtime zero-point
return [scalar_types.uint4]
@ -52,7 +62,7 @@ def query_marlin_supported_quant_types(
# GPTQ style, unsigned + symmetric bias
res = [scalar_types.uint4b8, scalar_types.uint8b128]
if include_fp_type:
res += [scalar_types.float8_e4m3fn]
res += [scalar_types.float8_e4m3fn, scalar_types.float4_e2m1f]
return res
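A minimal usage sketch of the relaxed signature (the capability value 90 below is only an illustrative stand-in to bypass the runtime platform query): leaving has_zp as None now yields the union of the zero-point and symmetric lists, so the new fp4 type appears alongside the existing GPTQ/AWQ/FP8 entries.

from vllm.model_executor.layers.quantization.utils.marlin_utils import (
    query_marlin_supported_quant_types)
from vllm.scalar_type import scalar_types

types = query_marlin_supported_quant_types(device_capability=90)
assert scalar_types.float4_e2m1f in types  # new symmetric fp4 entry
assert scalar_types.uint4 in types         # AWQ-style zero-point entry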
@ -394,6 +404,7 @@ def apply_gptq_marlin_linear(
None,
weight,
weight_scale,
None,
weight_zp,
g_idx,
g_idx_sort_indices,
@ -439,6 +450,7 @@ def apply_awq_marlin_linear(
None,
weight,
weight_scale,
None,
weight_zp,
g_idx,
g_idx_sort_indices,

View File

@ -0,0 +1,277 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Optional
import torch
import vllm._custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
USE_FP32_REDUCE_DEFAULT, marlin_make_workspace_new, marlin_permute_scales,
should_use_atomic_add_reduce)
from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types
FP4_MARLIN_SUPPORTED_GROUP_SIZES = [16]
logger = init_logger(__name__)
def is_fp4_marlin_supported():
return current_platform.has_device_capability(80)
def fp4_marlin_process_scales(marlin_scales):
assert (marlin_scales >= 0).all()
# convert to half first; we will convert to fp8 later
marlin_scales = marlin_scales.to(torch.half)
# 8 is the number of scales used by one thread
marlin_scales = marlin_scales.view(marlin_scales.size(0) // 2, 2, -1, 8)
marlin_scales = marlin_scales.permute(0, 2, 1, 3).reshape(
marlin_scales.size(0) * 2, -1)
# fit the layout of fp8 dequantization
marlin_scales = marlin_scales.view(-1, 4)[:, [0, 2, 1, 3]].view(
marlin_scales.size(0), -1)
# We assume that weight_scale (FP8-S1E4M3) is always greater
# than or equal to 0. So we can convert
# weight_scale * (2 ** 7) to a special FP8-S0E5M3 format.
# After multiplying by 2 ** 7, the top bit of FP8-S0E5M3 would always be 1
# when weight_scale > 0. This allows us to have an exponent bias
# closer to zero after dequantization.
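# Worked example (illustrative): a block scale of 2.0 stored in half
# becomes 2.0 * 2 ** 7 = 256.0, whose fp16 bit pattern is 0x5C00
# (sign 0, exponent 10111, mantissa 0). Shifting left by one drops the
# always-zero sign bit, giving 0xB800; keeping the high byte of each
# 16-bit element ([:, 1::2] below) leaves 0xB8, i.e. the 5-bit exponent
# plus the top 3 mantissa bits -- the FP8-S0E5M3 encoding described above.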
marlin_scales = (marlin_scales * (2**7)).view(torch.int16) << 1
marlin_scales = marlin_scales.view(torch.float8_e4m3fn)
marlin_scales = marlin_scales[:, 1::2].contiguous()
return marlin_scales
def fp4_marlin_process_global_scale(global_scale):
assert global_scale.dtype in [torch.half, torch.bfloat16]
fp4_exponent = 2
if global_scale.dtype == torch.half:
target_exponent = 5
elif global_scale.dtype == torch.bfloat16:
target_exponent = 8
# exponent_bias_fp16 = 2 ** 4 - 2 ** 1 = 14
# exponent_bias_bf16 = 2 ** 7 - 2 ** 1 = 126
exponent_bias = 2**(target_exponent - 1) - 2**(fp4_exponent - 1)
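# e.g. for fp16 this returns global_scale * 2 ** (14 - 7) = global_scale * 128.
# The 14 / 126 above equals the exponent-bias gap between e2m1 (bias 1) and
# fp16 / bf16 (bias 15 / 127); the extra 2 ** -7 presumably cancels the 2 ** 7
# folded into the block scales by fp4_marlin_process_scales.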
return global_scale * (2.0**(exponent_bias - 7))
def apply_fp4_marlin_linear(
input: torch.Tensor,
weight: torch.Tensor,
weight_scale: torch.Tensor,
weight_scale_2: torch.Tensor,
workspace: torch.Tensor,
size_n: int,
size_k: int,
bias: Optional[torch.Tensor] = None,
use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT) -> torch.Tensor:
# For GPUs that lack FP4 hardware support, we can leverage the
# Marlin kernel for fast weight-only FP4 quantization
reshaped_x = input.reshape(-1, input.shape[-1])
out_shape = input.shape[:-1] + (size_n, )
use_atomic_add = should_use_atomic_add_reduce(m=reshaped_x.size(0),
n=size_n,
k=size_k,
device=input.device,
dtype=input.dtype)
output = ops.gptq_marlin_gemm(a=reshaped_x,
c=None,
b_q_weight=weight,
b_scales=weight_scale,
global_scale=weight_scale_2,
b_zeros=None,
g_idx=None,
perm=None,
workspace=workspace,
b_q_type=scalar_types.float4_e2m1f,
size_m=reshaped_x.size(0),
size_n=size_n,
size_k=size_k,
use_atomic_add=use_atomic_add,
use_fp32_reduce=use_fp32_reduce)
if bias is not None:
output.add_(bias) # In-place add
return output.reshape(out_shape)
def prepare_fp4_layer_for_marlin(layer: torch.nn.Module) -> None:
logger.warning_once(
"Your GPU does not have native support for FP4 computation but "
"FP4 quantization is being used. Weight-only FP4 compression will "
"be used leveraging the Marlin kernel. This may degrade "
"performance for compute-heavy workloads.")
part_size_n = layer.output_size_per_partition
part_size_k = layer.input_size_per_partition
param_dtype = layer.params_dtype
assert layer.weight.shape == (part_size_n, part_size_k // 2)
device = layer.weight.device
# WORKSPACE
layer.workspace = marlin_make_workspace_new(device)
# WEIGHT
# Repack weights to marlin format
perm = torch.empty(0, dtype=torch.int, device=device)
qweight = layer.weight.view(torch.int32).T.contiguous()
marlin_qweight = ops.gptq_marlin_repack(b_q_weight=qweight,
perm=perm,
size_k=part_size_k,
size_n=part_size_n,
num_bits=4)
layer.weight = torch.nn.Parameter(marlin_qweight, requires_grad=False)
# WEIGHT SCALES
# Permute scales
weight_scale = layer.weight_scale.T.to(param_dtype)
weight_scale = marlin_permute_scales(s=weight_scale,
size_k=part_size_k,
size_n=part_size_n,
group_size=16)
weight_scale = fp4_marlin_process_scales(weight_scale)
layer.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False)
weight_scale_2 = layer.weight_scale_2.to(param_dtype)
weight_scale_2 = fp4_marlin_process_global_scale(weight_scale_2)
layer.weight_scale_2 = torch.nn.Parameter(weight_scale_2,
requires_grad=False)
return
def prepare_moe_fp4_layer_for_marlin(layer: torch.nn.Module) -> None:
logger.warning_once(
"Your GPU does not have native support for FP4 computation but "
"FP4 quantization is being used. Weight-only FP4 compression will "
"be used leveraging the Marlin kernel. This may degrade "
"performance for compute-heavy workloads.")
e = layer.num_experts
k = layer.hidden_size
n = layer.intermediate_size_per_partition
# WORKSPACE
device = layer.w13_weight.device
param_dtype = layer.params_dtype
layer.workspace = marlin_make_workspace_new(device, 4)
perm = torch.empty(0, dtype=torch.int, device=device)
# WEIGHT
# Repack weights to marlin format
for name in ["w13_weight", "w2_weight"]:
weight = getattr(layer, name)
tensor_list = []
if "w13" in name:
size_n, size_k = n * 2, k
else:
size_n, size_k = k, n
assert weight.shape == (e, size_n, size_k // 2)
for i in range(e):
qweight = weight[i].view(torch.int32).T.contiguous()
marlin_qweight = ops.gptq_marlin_repack(b_q_weight=qweight,
perm=perm,
size_k=size_k,
size_n=size_n,
num_bits=4)
tensor_list.append(marlin_qweight)
weight = torch.cat([x.unsqueeze(0) for x in tensor_list], 0)
weight = torch.nn.Parameter(weight, requires_grad=False)
setattr(layer, name, weight)
# WEIGHT SCALES
# Permute scales
for name in ["w13", "w2"]:
scales = getattr(layer, name + "_weight_scale").to(param_dtype)
global_scale = getattr(layer, name + "_weight_scale_2").to(param_dtype)
tensor_list = []
if "w13" in name:
size_n, size_k = n * 2, k
else:
size_n, size_k = k, n
for i in range(e):
marlin_scales = marlin_permute_scales(s=scales[i].T,
size_k=size_k,
size_n=size_n,
group_size=16)
marlin_scales = fp4_marlin_process_scales(marlin_scales)
tensor_list.append(marlin_scales)
scales = torch.cat([x.unsqueeze(0) for x in tensor_list], 0)
scales = torch.nn.Parameter(scales, requires_grad=False)
setattr(layer, name + "_weight_scale", scales)
global_scale = fp4_marlin_process_global_scale(global_scale)
global_scale = torch.nn.Parameter(global_scale, requires_grad=False)
setattr(layer, name + "_weight_scale_2", global_scale)
def rand_marlin_weight_fp4_like(weight, group_size):
assert group_size > 0
size_n, size_k = weight.shape
device = weight.device
scales = weight.view(size_n, -1, group_size).abs().max(-1)[0] / 6
global_scale = scales.max() / 448
scales = (scales / global_scale).to(torch.float8_e4m3fn)
fp4_weight = torch.randint(0,
256, (size_n, size_k // 2),
dtype=torch.uint8,
device=weight.device)
fp4_weight_part_1 = ((fp4_weight & 0b10000000) |
((fp4_weight & 0b01110000) >> 2))
fp4_weight_part_1 = fp4_weight_part_1.view(torch.float8_e4m3fn)
fp4_weight_part_1 = fp4_weight_part_1.to(weight.dtype) * (2**6)
fp4_weight2 = fp4_weight << 4
fp4_weight_part_2 = ((fp4_weight2 & 0b10000000) |
((fp4_weight2 & 0b01110000) >> 2))
fp4_weight_part_2 = fp4_weight_part_2.view(torch.float8_e4m3fn)
fp4_weight_part_2 = fp4_weight_part_2.to(weight.dtype) * (2**6)
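# Worked example (illustrative): the high nibble 0b0110 is e2m1 for
# 1.0 * 2 ** (3 - 1) = 4.0. The remap keeps the sign bit and shifts the
# 3-bit exponent/mantissa field into the e4m3 positions, so 0b01100000
# becomes 0b00011000, i.e. e4m3 2 ** (3 - 7) = 0.0625; the final * (2**6)
# restores 4.0 (the e4m3/e2m1 exponent-bias gap is 7 - 1 = 6).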
weight_ref = torch.cat(
[fp4_weight_part_2.unsqueeze(2),
fp4_weight_part_1.unsqueeze(2)], 2).view(size_n, size_k)
weight_ref = weight_ref * global_scale.to(weight.dtype) * \
scales.repeat_interleave(group_size, 1).to(weight.dtype)
marlin_qweight = ops.gptq_marlin_repack(
b_q_weight=fp4_weight.view(torch.int32).T.contiguous(),
perm=torch.empty(0, dtype=torch.int, device=device),
size_k=size_k,
size_n=size_n,
num_bits=4,
)
marlin_scales = marlin_permute_scales(s=scales.T.to(weight.dtype),
size_k=size_k,
size_n=size_n,
group_size=group_size)
marlin_scales = fp4_marlin_process_scales(marlin_scales)
global_scale = fp4_marlin_process_global_scale(global_scale)
return weight_ref.T, marlin_qweight, marlin_scales, global_scale

View File

@ -19,6 +19,20 @@ def is_fp8_marlin_supported():
return current_platform.has_device_capability(80)
def fp8_fused_exponent_bias_into_scales(scales):
fp8_exponent = 4
if scales.dtype == torch.half:
target_exponent = 5
elif scales.dtype == torch.bfloat16:
target_exponent = 8
# exponent_bias_fp16 = 2 ** 4 - 2 ** 3 = 8
# exponent_bias_bf16 = 2 ** 7 - 2 ** 3 = 120
exponent_bias = 2**(target_exponent - 1) - 2**(fp8_exponent - 1)
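# e.g. half scales end up multiplied by 2 ** 8 = 256 (bf16: 2 ** 120),
# matching the exponent-bias gap between the compute dtype and FP8-E4M3
# (15 - 7 and 127 - 7), presumably so the kernel's bit-level fp8
# dequantization needs no separate bias correction.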
s = torch.ones_like(scales) * 2
s = s**exponent_bias
return scales * s
def apply_fp8_marlin_linear(
input: torch.Tensor,
weight: torch.Tensor,
@ -44,6 +58,7 @@ def apply_fp8_marlin_linear(
c=None,
b_q_weight=weight,
b_scales=weight_scale,
global_scale=None,
b_zeros=None,
g_idx=None,
perm=None,
@ -132,8 +147,10 @@ def prepare_fp8_layer_for_marlin(layer: torch.nn.Module,
# block-wise quantization -> group-wise quantization
# (size_k // block_size[1], ceil(size_n / block_size[0]))
# =>(repeat)=> (size_k // block_size[1], size_n)
if not size_k_first:
scales = scales.T.contiguous()
block_n = layer.weight_block_size[0]
scales = scales.T.repeat_interleave(block_n, 1)
scales = scales.repeat_interleave(block_n, 1)
# size_n may not be divisible by block_size[0]
scales = scales[:, :part_size_n]
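# Illustrative shapes, assuming block_size = [128, 128], part_size_k = 2048
# and part_size_n = 4000: after the optional transpose, scales are
# (16, 32) = (part_size_k // 128, ceil(part_size_n / 128)); repeat_interleave
# expands this to (16, 4096) and the slice above trims it to (16, 4000).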
@ -141,6 +158,7 @@ def prepare_fp8_layer_for_marlin(layer: torch.nn.Module,
size_k=part_size_k,
size_n=part_size_n,
group_size=group_size)
marlin_scales = fp8_fused_exponent_bias_into_scales(marlin_scales)
layer.weight_scale = torch.nn.Parameter(marlin_scales, requires_grad=False)
@ -239,8 +257,10 @@ def prepare_moe_fp8_layer_for_marlin(layer: torch.nn.Module,
# block-wise quantization -> group-wise quantization
# (e, size_k // block_size[1], ceil(size_n / block_size[0]))
# =>(repeat)=> (e, size_k // block_size[1], size_n)
if not size_k_first:
scales = scales.permute(0, 2, 1)
block_n = layer.weight_block_size[0]
scales = scales.permute(0, 2, 1).repeat_interleave(block_n, 2)
scales = scales.repeat_interleave(block_n, 2)
# size_n may not be divisible by block_size[0]
scales = scales[..., :size_n].contiguous()
@ -302,4 +322,6 @@ def marlin_quant_fp8_torch(weight, group_size):
size_n=size_n,
group_size=group_size)
marlin_scales = fp8_fused_exponent_bias_into_scales(marlin_scales)
return weight_ref.T, marlin_qweight, marlin_scales