Mirror of https://github.com/vllm-project/vllm.git (synced 2025-10-20 14:53:52 +08:00)
[Kernel] fp4 marlin kernel (#17687)
Signed-off-by: Jinzhen Lin <linjinzhen@hotmail.com>
@ -315,6 +315,8 @@ static inline constexpr auto kS8 = ScalarType::int_(8);
|
||||
static inline constexpr auto kU8 = ScalarType::uint(8);
|
||||
static inline constexpr auto kU8B128 = ScalarType::uint(8, 128);
|
||||
|
||||
static inline constexpr auto kFE2M1f =
|
||||
ScalarType::float_(2, 1, true, ScalarType::NAN_NONE);
|
||||
static inline constexpr auto kFE3M2f =
|
||||
ScalarType::float_(3, 2, true, ScalarType::NAN_NONE);
|
||||
static inline constexpr auto kFE4M3fn =
|
||||
@ -332,6 +334,7 @@ static inline constexpr auto kInt8 = kS8;
|
||||
static inline constexpr auto kUint8 = kU8;
|
||||
static inline constexpr auto kUint8b128 = kU8B128;
|
||||
|
||||
static inline constexpr auto kFloat4_e2m1f = kFE2M1f;
|
||||
static inline constexpr auto kFloat6_e3m2f = kFE3M2f;
|
||||
static inline constexpr auto kFloat8_e4m3fn = kFE4M3fn;
|
||||
static inline constexpr auto kFloat8_e5m2 = kFE5M2;
|
||||
|
@ -31,7 +31,10 @@ TEMPLATE = ("template __global__ void Marlin<"
|
||||
|
||||
# int8 with zero point case (vllm::kU8) is also supported,
# but we don't add it, to reduce wheel size.
|
||||
SCALAR_TYPES = ["vllm::kU4", "vllm::kU4B8", "vllm::kU8B128", "vllm::kFE4M3fn"]
|
||||
SCALAR_TYPES = [
|
||||
"vllm::kU4", "vllm::kU4B8", "vllm::kU8B128", "vllm::kFE4M3fn",
|
||||
"vllm::kFE2M1f"
|
||||
]
|
||||
THREAD_CONFIGS = [(128, 128, 256), (64, 256, 256), (64, 128, 128)]
|
||||
|
||||
THREAD_M_BLOCKS = [0.5, 1, 2, 3, 4]
|
||||
@ -39,7 +42,7 @@ THREAD_M_BLOCKS = [0.5, 1, 2, 3, 4]
|
||||
# = 0 : act order case
|
||||
# = -1 : channelwise quantization
|
||||
# > 0 : group_size=16*group_blocks
|
||||
GROUP_BLOCKS = [0, -1, 2, 4, 8]
|
||||
GROUP_BLOCKS = [0, -1, 1, 2, 4, 8]
|
||||
DTYPES = ["fp16", "bf16"]
|
||||
|
||||
|
||||
@ -72,6 +75,12 @@ def generate_new_kernels():
|
||||
# for fp8
|
||||
if scalar_type == "vllm::kFE4M3fn" and group_blocks not in [-1, 8]:
|
||||
continue
|
||||
# nvfp4 only supports group_size == 16
|
||||
if scalar_type == "vllm::kFE2M1f" and group_blocks not in [1, 2]:
|
||||
continue
|
||||
# other quantization methods don't support group_size = 16
|
||||
if scalar_type != "vllm::kFE2M1f" and group_blocks == 1:
|
||||
continue
|
||||
|
||||
k_blocks = thread_configs[0] // 16
|
||||
n_blocks = thread_configs[1] // 16
|
||||
|
@ -7,17 +7,18 @@
|
||||
#include "quantization/gptq_marlin/marlin_dtypes.cuh"
|
||||
#include "core/scalar_type.hpp"
|
||||
|
||||
#define MARLIN_KERNEL_PARAMS \
|
||||
const int4 *__restrict__ A, const int4 *__restrict__ B, \
|
||||
int4 *__restrict__ C, int4 *__restrict__ C_tmp, \
|
||||
const int4 *__restrict__ scales_ptr, const int4 *__restrict__ zp_ptr, \
|
||||
const int *__restrict__ g_idx, \
|
||||
const int32_t *__restrict__ sorted_token_ids_ptr, \
|
||||
const int32_t *__restrict__ expert_ids_ptr, \
|
||||
const int32_t *__restrict__ num_tokens_past_padded_ptr, \
|
||||
const float *__restrict__ topk_weights_ptr, int top_k, \
|
||||
bool mul_topk_weights, bool is_ep, int num_groups, int prob_m, \
|
||||
int prob_n, int prob_k, int *locks, bool use_atomic_add, \
|
||||
#define MARLIN_KERNEL_PARAMS \
|
||||
const int4 *__restrict__ A, const int4 *__restrict__ B, \
|
||||
int4 *__restrict__ C, int4 *__restrict__ C_tmp, \
|
||||
const int4 *__restrict__ scales_ptr, \
|
||||
const uint16_t *__restrict__ scale2_ptr, \
|
||||
const int4 *__restrict__ zp_ptr, const int *__restrict__ g_idx, \
|
||||
const int32_t *__restrict__ sorted_token_ids_ptr, \
|
||||
const int32_t *__restrict__ expert_ids_ptr, \
|
||||
const int32_t *__restrict__ num_tokens_past_padded_ptr, \
|
||||
const float *__restrict__ topk_weights_ptr, int top_k, \
|
||||
bool mul_topk_weights, bool is_ep, int num_groups, int prob_m, \
|
||||
int prob_n, int prob_k, int *locks, bool use_atomic_add, \
|
||||
bool use_fp32_reduce, int max_shared_mem
|
||||
|
||||
namespace MARLIN_NAMESPACE_NAME {
|
||||
|
@ -301,9 +301,11 @@ __global__ void Marlin(
|
||||
int4* __restrict__ C_tmp, // fp32 tmp output buffer (for reduce)
|
||||
const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape
|
||||
// (k/groupsize)xn
|
||||
const int4* __restrict__ zp_ptr, // 4bit packed zero-points of shape
|
||||
// (k/groupsize)x(n/pack_factor)
|
||||
const int* __restrict__ g_idx, // int32 group indices of shape k
|
||||
const uint16_t* __restrict__ scale2_ptr, // fp16 global scale (for nvfp4
|
||||
// only)
|
||||
const int4* __restrict__ zp_ptr, // 4bit packed zero-points of shape
|
||||
// (k/groupsize)x(n/pack_factor)
|
||||
const int* __restrict__ g_idx, // int32 group indices of shape k
|
||||
const int32_t* __restrict__ sorted_token_ids_ptr, // moe sorted_ids
|
||||
const int32_t* __restrict__ expert_ids_ptr, // moe expert ids
|
||||
const int32_t* __restrict__ num_tokens_past_padded_ptr, // moe num tokens
|
||||
@ -341,6 +343,16 @@ __global__ void Marlin(
|
||||
extern __shared__ int4 sh[];
|
||||
static constexpr auto w_type = vllm::ScalarType::from_id(w_type_id);
|
||||
constexpr bool has_zp = w_type == vllm::kU4 || w_type == vllm::kU8;
|
||||
constexpr bool is_int_type = w_type == vllm::kU4 || w_type == vllm::kU8 ||
|
||||
w_type == vllm::kU4B8 || w_type == vllm::kU8B128;
|
||||
// see comments of dequant.h for more details
|
||||
constexpr bool dequant_skip_flop =
|
||||
!is_int_type ||
|
||||
has_zp && !is_zp_float && !std::is_same<scalar_t, nv_bfloat16>::value ||
|
||||
has_zp && !is_zp_float && !(w_type == vllm::kU8);
|
||||
|
||||
scalar_t2 global_scale;
|
||||
|
||||
constexpr bool has_act_order = group_blocks == 0;
|
||||
|
||||
constexpr int pack_factor = 32 / w_type.size_bits();
|
||||
@ -348,7 +360,8 @@ __global__ void Marlin(
|
||||
constexpr int moe_block_size = m_block_size_8 ? 8 : (16 * thread_m_blocks);
|
||||
const int group_size =
|
||||
(!has_act_order && group_blocks == -1) ? prob_k : prob_k / num_groups;
|
||||
const int scales_expert_stride = prob_n * prob_k / group_size / 8;
|
||||
const int scales_expert_stride =
|
||||
prob_n * prob_k / group_size / (w_type == vllm::kFE2M1f ? 16 : 8);
|
||||
const int zp_expert_stride =
|
||||
is_zp_float ? prob_n * prob_k / group_size / 8
|
||||
: prob_n * prob_k / group_size / (pack_factor * 4);
|
||||
@ -460,9 +473,16 @@ __global__ void Marlin(
|
||||
if (mul_topk_weights) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; i++) {
|
||||
sh_block_topk_weights[tid4 * 4 + i] =
|
||||
Dtype::num2num2(Dtype::float2num(
|
||||
topk_weights_ptr[sh_block_sorted_ids[tid4 * 4 + i]]));
|
||||
if constexpr (w_type == vllm::kFE2M1f) {
|
||||
sh_block_topk_weights[tid4 * 4 + i] = __hmul2(
|
||||
global_scale,
|
||||
Dtype::num2num2(Dtype::float2num(
|
||||
topk_weights_ptr[sh_block_sorted_ids[tid4 * 4 + i]])));
|
||||
} else {
|
||||
sh_block_topk_weights[tid4 * 4 + i] =
|
||||
Dtype::num2num2(Dtype::float2num(
|
||||
topk_weights_ptr[sh_block_sorted_ids[tid4 * 4 + i]]));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -493,6 +513,11 @@ __global__ void Marlin(
|
||||
expert_id = expert_ids_ptr[block_id];
|
||||
}
|
||||
|
||||
if constexpr (w_type == vllm::kFE2M1f) {
|
||||
uint16_t val = scale2_ptr[expert_id];
|
||||
global_scale = Dtype::num2num2(*reinterpret_cast<scalar_t*>(&val));
|
||||
}
|
||||
|
||||
B_expert_off = expert_id * prob_n * prob_k / (pack_factor * 4);
|
||||
scales_ptr += (expert_id - old_expert_id) * scales_expert_stride;
|
||||
if constexpr (has_zp) {
|
||||
@ -606,7 +631,7 @@ __global__ void Marlin(
|
||||
constexpr int s_sh_stride = 16 * thread_n_blocks / 8;
|
||||
constexpr int s_tb_groups =
|
||||
!has_act_order && group_blocks != -1 && group_blocks < thread_k_blocks
|
||||
? thread_k_blocks / group_blocks
|
||||
? thread_k_blocks / group_blocks / (w_type == vllm::kFE2M1f ? 2 : 1)
|
||||
: 1;
|
||||
constexpr int s_sh_stage = s_tb_groups * s_sh_stride;
|
||||
int s_gl_rd_delta = s_gl_stride;
|
||||
@ -664,7 +689,8 @@ __global__ void Marlin(
|
||||
if constexpr (group_blocks == -1) {
|
||||
s_gl_rd = s_sh_stride * slice_col + threadIdx.x;
|
||||
} else {
|
||||
s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) +
|
||||
s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) /
|
||||
(w_type == vllm::kFE2M1f ? 2 : 1) +
|
||||
s_sh_stride * slice_col + threadIdx.x;
|
||||
}
|
||||
}
|
||||
@ -688,10 +714,20 @@ __global__ void Marlin(
|
||||
// we scale a `half2` tile in column-major layout in the former and in
|
||||
// row-major in the latter case.
|
||||
int s_sh_rd;
|
||||
if constexpr (group_blocks != -1)
|
||||
if constexpr (group_blocks != -1 && w_type == vllm::kFE2M1f) {
|
||||
auto warp_id = threadIdx.x / 32;
|
||||
int n_warps = thread_n_blocks / 4;
|
||||
int warp_row = warp_id / n_warps;
|
||||
|
||||
s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
|
||||
(threadIdx.x % 32) / 4;
|
||||
else if constexpr (group_blocks == -1 && (m_block_size_8 || has_zp))
|
||||
s_sh_rd = s_sh_rd * 2 + warp_row % 2;
|
||||
|
||||
} else if constexpr (group_blocks != -1)
|
||||
s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
|
||||
(threadIdx.x % 32) / 4;
|
||||
else if constexpr (group_blocks == -1 &&
|
||||
(m_block_size_8 || (has_zp && !dequant_skip_flop)))
|
||||
s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
|
||||
(threadIdx.x % 32) / 8;
|
||||
else
|
||||
@ -801,7 +837,7 @@ __global__ void Marlin(
|
||||
sh_first_group_id = first_group_id;
|
||||
sh_num_groups = last_group_id - first_group_id + 1;
|
||||
|
||||
if (sh_num_groups < act_s_max_num_groups) {
|
||||
if (sh_num_groups > act_s_max_num_groups) {
|
||||
sh_num_groups = act_s_max_num_groups;
|
||||
}
|
||||
|
||||
@ -1021,12 +1057,19 @@ __global__ void Marlin(
|
||||
cur_k += k_iter_size * (k % b_sh_wr_iters);
|
||||
|
||||
int k_blocks = cur_k / 16;
|
||||
int cur_group_id = k_blocks / group_blocks;
|
||||
int cur_group_id =
|
||||
k_blocks / (group_blocks * (w_type == vllm::kFE2M1f ? 2 : 1));
|
||||
|
||||
int4* sh_s_stage = sh_s + s_sh_stage * pipe;
|
||||
|
||||
reinterpret_cast<int4*>(&frag_s[k % 2])[0] =
|
||||
sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride];
|
||||
if constexpr (w_type_id != vllm::kFE2M1f.id()) {
|
||||
reinterpret_cast<int4*>(&frag_s[k % 2])[0] =
|
||||
sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride];
|
||||
} else {
|
||||
reinterpret_cast<int2*>(&frag_s[k % 2])[0] =
|
||||
reinterpret_cast<int2*>(
|
||||
sh_s_stage)[s_sh_rd + cur_group_id * (2 * s_sh_stride)];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -1199,22 +1242,7 @@ __global__ void Marlin(
|
||||
};
|
||||
|
||||
auto dequant_data = [&](int q, scalar_t2* frag_b_ptr) {
|
||||
if constexpr (has_zp && is_zp_float || !has_zp) {
|
||||
dequant<scalar_t2, w_type_id>(q, frag_b_ptr);
|
||||
} else {
|
||||
static_assert(has_zp && !is_zp_float);
|
||||
static_assert(w_type_id == vllm::kU4.id() || w_type_id == vllm::kU8.id());
|
||||
// If (has_zp && !is_zp_float),
|
||||
// we use not-zp version `dequant` function
|
||||
// to improve numerical accuracy.
|
||||
// Since both weight and zero point are dequanted using this logic,
|
||||
// the final dequanted weight would be correct.
|
||||
if constexpr (w_type_id == vllm::kU4.id()) {
|
||||
dequant<scalar_t2, vllm::kU4B8.id()>(q, frag_b_ptr);
|
||||
} else if constexpr (w_type_id == vllm::kU8.id()) {
|
||||
dequant<scalar_t2, vllm::kU8B128.id()>(q, frag_b_ptr);
|
||||
}
|
||||
}
|
||||
dequant<scalar_t2, w_type_id, dequant_skip_flop>(q, frag_b_ptr);
|
||||
};
|
||||
|
||||
// Execute the actual tensor core matmul of a sub-tile.
|
||||
@ -1244,13 +1272,23 @@ __global__ void Marlin(
|
||||
dequant_data(zp_quant_1, reinterpret_cast<scalar_t2*>(&frag_zp) + 2);
|
||||
}
|
||||
}
|
||||
if constexpr (has_zp && is_zp_float) {
|
||||
if constexpr (!dequant_skip_flop && has_zp && is_zp_float) {
|
||||
if (is_new_zp) {
|
||||
reinterpret_cast<int4*>(&frag_zp)[0] =
|
||||
reinterpret_cast<int4*>(&frag_zpf[k2])[0];
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr (w_type == vllm::kFE2M1f) {
|
||||
int s_quant_0 = reinterpret_cast<int*>(frag_s[k2])[0];
|
||||
int s_quant_1 = reinterpret_cast<int*>(frag_s[k2])[1];
|
||||
|
||||
dequant_fp8_scales<scalar_t2>(s_quant_0,
|
||||
reinterpret_cast<scalar_t2*>(&frag_s[k2]));
|
||||
dequant_fp8_scales<scalar_t2>(
|
||||
s_quant_1, reinterpret_cast<scalar_t2*>(&frag_s[k2]) + 2);
|
||||
}
|
||||
|
||||
// We have the m dimension as the inner loop in order to encourage overlapping
|
||||
// dequantization and matmul operations.
|
||||
#pragma unroll
|
||||
@ -1259,7 +1297,10 @@ __global__ void Marlin(
|
||||
FragB frag_b1;
|
||||
int b_quant_0, b_quant_1;
|
||||
|
||||
if constexpr (w_type.size_bits() == 4) {
|
||||
if constexpr (w_type_id == vllm::kFE2M1f.id()) {
|
||||
b_quant_1 = frag_b_quant[k2][0][j];
|
||||
b_quant_0 = b_quant_1 << 8;
|
||||
} else if constexpr (w_type.size_bits() == 4) {
|
||||
b_quant_0 = frag_b_quant[k2][0][j];
|
||||
b_quant_1 = b_quant_0 >> 8;
|
||||
} else {
|
||||
@ -1272,6 +1313,11 @@ __global__ void Marlin(
|
||||
dequant_data(b_quant_0, reinterpret_cast<scalar_t2*>(&frag_b0));
|
||||
dequant_data(b_quant_1, reinterpret_cast<scalar_t2*>(&frag_b1));
|
||||
|
||||
if constexpr (dequant_skip_flop && has_zp && !is_zp_float) {
|
||||
sub_zp<scalar_t>(frag_b0, frag_zp[j], 0);
|
||||
sub_zp<scalar_t>(frag_b1, frag_zp[j], 1);
|
||||
}
|
||||
|
||||
// Apply scale to frag_b0
|
||||
if constexpr (has_act_order) {
|
||||
static_assert(group_blocks != -1);
|
||||
@ -1279,7 +1325,8 @@ __global__ void Marlin(
|
||||
act_frag_s[k2][2][j], act_frag_s[k2][3][j], 0);
|
||||
scale4<scalar_t>(frag_b1, act_frag_s[k2][0][j], act_frag_s[k2][1][j],
|
||||
act_frag_s[k2][2][j], act_frag_s[k2][3][j], 1);
|
||||
} else if constexpr (has_zp && !is_zp_float && group_blocks == -1) {
|
||||
} else if constexpr (!dequant_skip_flop && has_zp && !is_zp_float &&
|
||||
group_blocks == -1) {
|
||||
int idx = (threadIdx.x / 4) % 2;
|
||||
scalar_t2 s2 = Dtype::nums2num2(
|
||||
reinterpret_cast<scalar_t*>(&frag_s[j / 2][j % 2 * 2 + 0])[idx],
|
||||
@ -1287,7 +1334,7 @@ __global__ void Marlin(
|
||||
if (is_new_zp) frag_zp[j] = __hmul2(frag_zp[j], s2);
|
||||
scale_and_sub<scalar_t>(frag_b0, s2.x, frag_zp[j].x);
|
||||
scale_and_sub<scalar_t>(frag_b1, s2.y, frag_zp[j].y);
|
||||
} else if constexpr (has_zp && group_blocks != -1) {
|
||||
} else if constexpr (!dequant_skip_flop && has_zp && group_blocks != -1) {
|
||||
if (is_new_zp)
|
||||
frag_zp[j] = __hmul2(frag_zp[j],
|
||||
*reinterpret_cast<scalar_t2*>(&frag_s[k2][j]));
|
||||
@ -1554,10 +1601,17 @@ __global__ void Marlin(
|
||||
// For per-column quantization we finally apply the scale here (only for
|
||||
// 4-bit)
|
||||
if constexpr (!has_act_order && group_blocks == -1 &&
|
||||
w_type.size_bits() == 4 && !has_zp) {
|
||||
w_type.size_bits() == 4 &&
|
||||
(has_zp && dequant_skip_flop || !has_zp)) {
|
||||
res = __hmul2(res, s[0]);
|
||||
}
|
||||
|
||||
if constexpr (w_type == vllm::kFE2M1f) {
|
||||
if (!mul_topk_weights) {
|
||||
res = __hmul2(res, global_scale);
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr (m_block_size_8) {
|
||||
((scalar_t*)sh_red)[idx] = res.x;
|
||||
((scalar_t*)sh_red)[idx + 8 * c_sh_stride] = res.y;
|
||||
@ -1648,7 +1702,9 @@ __global__ void Marlin(
|
||||
if constexpr (has_zp && !is_zp_float && group_blocks == -1) {
|
||||
if (i == 0) {
|
||||
fetch_col_zp_to_shared();
|
||||
fetch_col_scale_to_shared();
|
||||
if constexpr (!dequant_skip_flop) {
|
||||
fetch_col_scale_to_shared();
|
||||
}
|
||||
}
|
||||
}
|
||||
fetch_to_shared(i, i, i < slice_iters, i);
|
||||
@ -1737,7 +1793,8 @@ __global__ void Marlin(
|
||||
bool last = slice_idx == slice_count - 1;
|
||||
// For per-column scales, we only fetch them here in the final step before
|
||||
// write-out
|
||||
if constexpr (!has_act_order && group_blocks == -1 && !has_zp) {
|
||||
if constexpr (!has_act_order && group_blocks == -1 &&
|
||||
(has_zp && dequant_skip_flop || !has_zp)) {
|
||||
if (w_type.size_bits() == 8 || (last || use_atomic_add)) {
|
||||
if (s_sh_wr_pred) {
|
||||
cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]);
|
||||
@ -1747,7 +1804,8 @@ __global__ void Marlin(
|
||||
}
|
||||
|
||||
thread_block_reduce();
|
||||
if constexpr (!has_act_order && group_blocks == -1 && !has_zp) {
|
||||
if constexpr (!has_act_order && group_blocks == -1 &&
|
||||
(has_zp && dequant_skip_flop || !has_zp)) {
|
||||
if (w_type.size_bits() == 8 || (last || use_atomic_add)) {
|
||||
cp_async_wait<0>();
|
||||
__syncthreads();
|
||||
@ -1771,7 +1829,8 @@ __global__ void Marlin(
|
||||
// that converts the fp32 results to fp16 (so that we avoid possible
|
||||
// overflow in fp16)
|
||||
if constexpr (!has_act_order && group_blocks == -1 &&
|
||||
w_type.size_bits() == 8 && !has_zp) {
|
||||
w_type.size_bits() == 8 &&
|
||||
(has_zp && dequant_skip_flop || !has_zp)) {
|
||||
if (threadIdx.x / 32 < thread_n_blocks / 4) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < thread_m_blocks; i++) {
|
||||
|
@ -291,6 +291,7 @@ bool is_valid_config(thread_config_t const& th_config, bool m_block_size_8,
|
||||
// BIGGROUP: cases for big group size (group_blocks in [-1, 8])
|
||||
// FZP: cases for float-zero-point (is_zp_float = true)
|
||||
// ACT: cases for act order case (group_blocks == 0)
|
||||
// FP4: cases for nvfp4(e2m1) (group_blocks == 1)
|
||||
#define COMMON_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
|
||||
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, -1, NUM_THREADS, false) \
|
||||
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 2, NUM_THREADS, false) \
|
||||
@ -338,6 +339,21 @@ bool is_valid_config(thread_config_t const& th_config, bool m_block_size_8,
|
||||
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
|
||||
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false)
|
||||
|
||||
#define FP4_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
|
||||
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 1, NUM_THREADS, false) \
|
||||
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false)
|
||||
|
||||
#define FP4_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
|
||||
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \
|
||||
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \
|
||||
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false)
|
||||
|
||||
#define FP4_GET_IF(W_TYPE) \
|
||||
FP4_GET_IF_M1(W_TYPE, 8, 8, 256) \
|
||||
FP4_GET_IF_M1(W_TYPE, 8, 4, 128) \
|
||||
FP4_GET_IF_M234(W_TYPE, 16, 4, 256) \
|
||||
FP4_GET_IF_M234(W_TYPE, 8, 4, 128)
|
||||
|
||||
#define BIGGROUP_GET_IF(W_TYPE) \
|
||||
BIGGROUP_GET_IF_M1(W_TYPE, 8, 8, 256) \
|
||||
BIGGROUP_GET_IF_M1(W_TYPE, 8, 4, 128) \
|
||||
@ -394,6 +410,8 @@ MarlinFuncPtr get_marlin_kernel(const vllm::ScalarType q_type,
|
||||
|
||||
BIGGROUP_GET_IF(vllm::kFE4M3fn)
|
||||
|
||||
FP4_GET_IF(vllm::kFE2M1f)
|
||||
|
||||
ACT_GET_IF(vllm::kU4B8)
|
||||
ACT_GET_IF(vllm::kU8B128)
|
||||
|
||||
@ -465,7 +483,7 @@ exec_config_t determine_exec_config(const vllm::ScalarType& q_type, int prob_m,
|
||||
|
||||
template <typename scalar_t>
|
||||
void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
|
||||
void* zp, void* g_idx, void* perm, void* a_tmp,
|
||||
void* s2, void* zp, void* g_idx, void* perm, void* a_tmp,
|
||||
void* sorted_token_ids, void* expert_ids,
|
||||
void* num_tokens_past_padded, void* topk_weights,
|
||||
int moe_block_size, int top_k, bool mul_topk_weights, bool is_ep,
|
||||
@ -479,14 +497,16 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
|
||||
bool m_block_size_8 = moe_block_size == 8;
|
||||
|
||||
if (has_zp) {
|
||||
TORCH_CHECK(q_type == vllm::kU4,
|
||||
"q_type must be u4 when has_zp = True. Got = ", q_type.str());
|
||||
TORCH_CHECK(
|
||||
q_type == vllm::kU4 || q_type == vllm::kU8,
|
||||
"q_type must be u4 or u8 when has_zp = True. Got = ", q_type.str());
|
||||
} else {
|
||||
TORCH_CHECK(q_type == vllm::kU4B8 || q_type == vllm::kU8B128 ||
|
||||
q_type == vllm::kFE4M3fn,
|
||||
"q_type must be uint4b8, uint8b128 or fp8e4m3 when has_zp = "
|
||||
"False. Got = ",
|
||||
q_type.str());
|
||||
TORCH_CHECK(
|
||||
q_type == vllm::kU4B8 || q_type == vllm::kU8B128 ||
|
||||
q_type == vllm::kFE4M3fn || q_type == vllm::kFE2M1f,
|
||||
"q_type must be uint4b8, uint8b128, float8_e4m3fn or float4_e2m1f when "
|
||||
"has_zp = False. Got = ",
|
||||
q_type.str());
|
||||
}
|
||||
|
||||
TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m,
|
||||
@ -519,6 +539,7 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
|
||||
int4* C_ptr = (int4*)C;
|
||||
int4* C_tmp_ptr = (int4*)C_tmp;
|
||||
const int4* s_ptr = (const int4*)s;
|
||||
const uint16_t* s2_ptr = (const uint16_t*)s2;
|
||||
const int4* zp_ptr = (const int4*)zp;
|
||||
const int* g_idx_ptr = (const int*)g_idx;
|
||||
const int* perm_ptr = (const int*)perm;
|
||||
@ -627,7 +648,7 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
|
||||
// avoid ">>>" being formatted to "> > >"
|
||||
// clang-format off
|
||||
kernel<<<blocks, num_threads, max_shared_mem, stream>>>(
|
||||
A_ptr, B_ptr, C_ptr, C_tmp_ptr, s_ptr, zp_ptr, g_idx_ptr,
|
||||
A_ptr, B_ptr, C_ptr, C_tmp_ptr, s_ptr, s2_ptr, zp_ptr, g_idx_ptr,
|
||||
sorted_token_ids_ptr, expert_ids_ptr, num_tokens_past_padded_ptr,
|
||||
topk_weights_ptr, top_k, mul_topk_weights, is_ep, num_groups, prob_m,
|
||||
prob_n, prob_k, locks, use_atomic_add, use_fp32_reduce, max_shared_mem);
|
||||
@ -639,6 +660,7 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
|
||||
torch::Tensor moe_wna16_marlin_gemm(
|
||||
torch::Tensor& a, std::optional<torch::Tensor> const& c_or_none,
|
||||
torch::Tensor& b_q_weight, torch::Tensor& b_scales,
|
||||
std::optional<torch::Tensor> const& global_scale_or_none,
|
||||
std::optional<torch::Tensor> const& b_zeros_or_none,
|
||||
std::optional<torch::Tensor> const& g_idx_or_none,
|
||||
std::optional<torch::Tensor> const& perm_or_none, torch::Tensor& workspace,
|
||||
@ -790,6 +812,17 @@ torch::Tensor moe_wna16_marlin_gemm(
|
||||
}
|
||||
}
|
||||
|
||||
torch::Tensor global_scale;
|
||||
if (global_scale_or_none.has_value()) {
|
||||
global_scale = global_scale_or_none.value();
|
||||
TORCH_CHECK(b_q_type == vllm::kFE2M1f,
|
||||
"global_scale can only be used for float4_e2m1f.");
|
||||
} else {
|
||||
global_scale = torch::empty({0}, options);
|
||||
TORCH_CHECK(!(b_q_type == vllm::kFE2M1f),
|
||||
"the global_scale parameter must be passed for float4_e2m1f.");
|
||||
}
|
||||
|
||||
torch::Tensor b_zeros;
|
||||
if (b_zeros_or_none.has_value()) {
|
||||
b_zeros = b_zeros_or_none.value();
|
||||
@ -802,13 +835,14 @@ torch::Tensor moe_wna16_marlin_gemm(
|
||||
|
||||
if (has_zp) {
|
||||
TORCH_CHECK(
|
||||
b_q_type == vllm::kU4,
|
||||
"b_q_type must be u4 when has_zp = True. Got = ", b_q_type.str());
|
||||
b_q_type == vllm::kU4 || b_q_type == vllm::kU8,
|
||||
"b_q_type must be u4 or u8 when has_zp = True. Got = ", b_q_type.str());
|
||||
} else {
|
||||
TORCH_CHECK(b_q_type == vllm::kU4B8 || b_q_type == vllm::kU8B128 ||
|
||||
b_q_type == vllm::kFE4M3fn,
|
||||
"b_q_type must be uint4b8, uint8b128 or fp8e4m3 when has_zp = "
|
||||
"False. Got = ",
|
||||
b_q_type == vllm::kFE4M3fn || b_q_type == vllm::kFE2M1f,
|
||||
"b_q_type must be uint4b8, uint8b128, float8_e4m3fn or "
|
||||
"float4_e2m1f when "
|
||||
"has_zp = False. Got = ",
|
||||
b_q_type.str());
|
||||
}
|
||||
|
||||
@ -854,9 +888,16 @@ torch::Tensor moe_wna16_marlin_gemm(
|
||||
|
||||
int dev = a.get_device();
|
||||
if (a.scalar_type() == at::ScalarType::Half) {
|
||||
void* scales_ptr;
|
||||
if (b_q_type == vllm::kFE2M1f) {
|
||||
scales_ptr = b_scales.data_ptr<at::Float8_e4m3fn>();
|
||||
} else {
|
||||
scales_ptr = b_scales.data_ptr<at::Half>();
|
||||
}
|
||||
|
||||
MARLIN_NAMESPACE_NAME::marlin_mm<half>(
|
||||
a.data_ptr<at::Half>(), b_q_weight.data_ptr(), c.data_ptr<at::Half>(),
|
||||
c_tmp.data_ptr<float>(), b_scales.data_ptr<at::Half>(),
|
||||
c_tmp.data_ptr<float>(), scales_ptr, global_scale.data_ptr<at::Half>(),
|
||||
b_zeros.data_ptr(), g_idx.data_ptr(), perm.data_ptr(),
|
||||
a_tmp.data_ptr<at::Half>(), sorted_token_ids.data_ptr(),
|
||||
expert_ids.data_ptr(), num_tokens_past_padded.data_ptr(),
|
||||
@ -866,11 +907,18 @@ torch::Tensor moe_wna16_marlin_gemm(
|
||||
at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms,
|
||||
use_atomic_add, use_fp32_reduce, is_zp_float);
|
||||
} else if (a.scalar_type() == at::ScalarType::BFloat16) {
|
||||
void* scales_ptr;
|
||||
if (b_q_type == vllm::kFE2M1f) {
|
||||
scales_ptr = b_scales.data_ptr<at::Float8_e4m3fn>();
|
||||
} else {
|
||||
scales_ptr = b_scales.data_ptr<at::BFloat16>();
|
||||
}
|
||||
|
||||
MARLIN_NAMESPACE_NAME::marlin_mm<nv_bfloat16>(
|
||||
a.data_ptr<at::BFloat16>(), b_q_weight.data_ptr(),
|
||||
c.data_ptr<at::BFloat16>(), c_tmp.data_ptr<float>(),
|
||||
b_scales.data_ptr<at::BFloat16>(), b_zeros.data_ptr(), g_idx.data_ptr(),
|
||||
perm.data_ptr(), a_tmp.data_ptr<at::BFloat16>(),
|
||||
c.data_ptr<at::BFloat16>(), c_tmp.data_ptr<float>(), scales_ptr,
|
||||
global_scale.data_ptr<at::BFloat16>(), b_zeros.data_ptr(),
|
||||
g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr<at::BFloat16>(),
|
||||
sorted_token_ids.data_ptr(), expert_ids.data_ptr(),
|
||||
num_tokens_past_padded.data_ptr(), topk_weights.data_ptr(),
|
||||
moe_block_size, top_k, mul_topk_weights, is_ep, size_m, size_n, size_k,
|
||||
|
@ -44,7 +44,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
|
||||
|
||||
m.def(
|
||||
"moe_wna16_marlin_gemm(Tensor! a, Tensor? c_or_none,"
|
||||
"Tensor! b_q_weight, Tensor! b_scales, Tensor? b_zeros_or_none,"
|
||||
"Tensor! b_q_weight, Tensor! b_scales, Tensor? global_scale, Tensor? "
|
||||
"b_zeros_or_none,"
|
||||
"Tensor? g_idx_or_none, Tensor? perm_or_none, Tensor! workspace,"
|
||||
"Tensor sorted_token_ids,"
|
||||
"Tensor! expert_ids, Tensor! num_tokens_past_padded,"
|
||||
|
@ -1,3 +1,67 @@
/*
Fast Dequantization (Converting INT4/INT8/FP4/FP8 to FP16/BF16)

The process of fast dequantization can be summarized as a combination
of bitwise operations and floating-point computations:

weight =>(bit_op / bitwise operations)=>
f16_value =>(flop / floating-point computation)=>
dequantized_weight

Since the dequantized weights typically require subtracting the zero point and
applying a scale factor, the floating-point computation step can be fused with
the zero-point subtraction and scaling operations.

The following are the parts that need to be modified for the fused operation
of zero-point subtraction and scaling.
## INT4 => FP16/BF16 or INT8 => FP16

The floating-point computation is `__hsub2`.

If there are zero points:

flop(bit_op(weight)) - flop(bit_op(zp))
  = sub(bit_op(weight), bias) - sub(bit_op(zp), bias)
  = bit_op(weight) - bit_op(zp)

so we don't need any additional modification.

If there are float zero points:

flop(bit_op(weight)) - fzp
  = sub(bit_op(weight), bias) - fzp
  = bit_op(weight) - (fzp + bias)

where `fzp + bias` can be computed at weight loading. But this may cause
accuracy issues, so we should not use it in most cases.

If there are no zero points:

scale(flop(bit_op(weight)))
  = scale(sub(bit_op(weight), bias))
  = scale(bit_op(weight)) - scale(bias)
  = fma(bit_op(weight), scale_factor, scale(bias))

where `scale(bias)` can be cached. But this may cause accuracy issues, so we
should not use it in most cases.
## INT8 => BF16

INT8 => BF16 is a special case: it uses byte_perm instead of a flop,
and we cannot fuse byte_perm with scaling.
## FP4/FP8 => FP16/BF16

scale(flop(bit_op(weight)))
  = scale(mul(bit_op(weight), multiplier))
  = mul(bit_op(weight), scale_factor * multiplier)

where `scale_factor * multiplier` can be computed at weight loading.

*/

#include "marlin_dtypes.cuh"
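The FP4/FP8 folding described in the comment above amounts to multiplying the group scale once, at weight-loading time, by the constant 2^(dst_exponent_bias - src_exponent_bias). Below is a minimal host-side sketch of that idea (illustrative only, not part of this commit; the helper name `folded_scale` is made up): the multiplier is 2^8 for FP8 (E4M3, bias 7) to FP16 (bias 15), and 2^14 for FP4 (E2M1, bias 1) to FP16.

// Illustrative sketch only -- not part of this commit; `folded_scale` and
// `folded_scale_examples` are made-up helper names.
#include <cassert>

// multiplier = 2^(dst_exponent_bias - src_exponent_bias).  Folding it into the
// per-group scale at weight loading leaves the kernel with a single multiply:
//   scale(mul(w, multiplier)) == mul(w, scale_factor * multiplier)
static float folded_scale(float scale_factor, int bias_offset) {
  return scale_factor * float(1 << bias_offset);
}

static void folded_scale_examples() {
  // FP8 (E4M3, bias 7) -> FP16 (bias 15): multiplier = 2^8
  assert(folded_scale(1.0f, 8) == 256.0f);
  // FP4 (E2M1, bias 1) -> FP16 (bias 15): multiplier = 2^14
  assert(folded_scale(1.0f, 14) == 16384.0f);
}
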
@ -27,7 +91,8 @@ __device__ inline uint32_t prmt(uint32_t a) {
|
||||
return res;
|
||||
}
|
||||
|
||||
template <typename scalar_t2, vllm::ScalarTypeId w_type_id>
|
||||
template <typename scalar_t2, vllm::ScalarTypeId w_type_id,
|
||||
bool skip_flop = false>
|
||||
__device__ inline void dequant(int q, scalar_t2* frag_b);
|
||||
|
||||
//
|
||||
@ -40,7 +105,22 @@ __device__ inline void dequant(int q, scalar_t2* frag_b);
|
||||
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L327-L385
|
||||
//
|
||||
template <>
|
||||
__device__ inline void dequant<half2, vllm::kU4B8.id()>(int q, half2* frag_b) {
|
||||
__device__ inline void dequant<half2, vllm::kU4B8.id(), true>(int q,
|
||||
half2* frag_b) {
|
||||
const int MASK = 0x000f000f;
|
||||
const int EX = 0x64006400;
|
||||
// Guarantee that the `(a & b) | c` operations are LOP3s.
|
||||
int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
|
||||
q >>= 4;
|
||||
int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
|
||||
|
||||
frag_b[0] = *reinterpret_cast<half2*>(&lo);
|
||||
frag_b[1] = *reinterpret_cast<half2*>(&hi);
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline void dequant<half2, vllm::kU4B8.id(), false>(int q,
|
||||
half2* frag_b) {
|
||||
const int LO = 0x000f000f;
|
||||
const int HI = 0x00f000f0;
|
||||
const int EX = 0x64006400;
|
||||
@ -62,7 +142,14 @@ __device__ inline void dequant<half2, vllm::kU4B8.id()>(int q, half2* frag_b) {
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline void dequant<half2, vllm::kU4.id()>(int q, half2* frag_b) {
|
||||
__device__ inline void dequant<half2, vllm::kU4.id(), true>(int q,
|
||||
half2* frag_b) {
|
||||
dequant<half2, vllm::kU4B8.id(), true>(q, frag_b);
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline void dequant<half2, vllm::kU4.id(), false>(int q,
|
||||
half2* frag_b) {
|
||||
const int LO = 0x000f000f;
|
||||
const int HI = 0x00f000f0;
|
||||
const int EX = 0x64006400;
|
||||
@ -84,7 +171,7 @@ __device__ inline void dequant<half2, vllm::kU4.id()>(int q, half2* frag_b) {
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline void dequant<nv_bfloat162, vllm::kU4B8.id()>(
|
||||
__device__ inline void dequant<nv_bfloat162, vllm::kU4B8.id(), true>(
|
||||
int q, nv_bfloat162* frag_b) {
|
||||
static constexpr uint32_t MASK = 0x000f000f;
|
||||
static constexpr uint32_t EX = 0x43004300;
|
||||
@ -96,39 +183,36 @@ __device__ inline void dequant<nv_bfloat162, vllm::kU4B8.id()>(
|
||||
int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
|
||||
// clang-format on
|
||||
|
||||
static constexpr uint32_t MUL = 0x3F803F80;
|
||||
static constexpr uint32_t ADD = 0xC308C308;
|
||||
|
||||
frag_b[0] = __hfma2(*reinterpret_cast<nv_bfloat162*>(&lo),
|
||||
*reinterpret_cast<const nv_bfloat162*>(&MUL),
|
||||
*reinterpret_cast<const nv_bfloat162*>(&ADD));
|
||||
frag_b[1] = __hfma2(*reinterpret_cast<nv_bfloat162*>(&hi),
|
||||
*reinterpret_cast<const nv_bfloat162*>(&MUL),
|
||||
*reinterpret_cast<const nv_bfloat162*>(&ADD));
|
||||
frag_b[0] = *reinterpret_cast<nv_bfloat162*>(&lo);
|
||||
frag_b[1] = *reinterpret_cast<nv_bfloat162*>(&hi);
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline void dequant<nv_bfloat162, vllm::kU4.id()>(
|
||||
__device__ inline void dequant<nv_bfloat162, vllm::kU4B8.id(), false>(
|
||||
int q, nv_bfloat162* frag_b) {
|
||||
static constexpr uint32_t MASK = 0x000f000f;
|
||||
static constexpr uint32_t EX = 0x43004300;
|
||||
dequant<nv_bfloat162, vllm::kU4B8.id(), true>(q, frag_b);
|
||||
|
||||
// Guarantee that the `(a & b) | c` operations are LOP3s.
|
||||
// clang-format off
|
||||
int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
|
||||
q >>= 4;
|
||||
int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
|
||||
// clang-format on
|
||||
static constexpr uint32_t SUB = 0x43084308;
|
||||
|
||||
static constexpr uint32_t MUL = 0x3F803F80;
|
||||
static constexpr uint32_t ADD = 0xC300C300;
|
||||
frag_b[0] = __hsub2(frag_b[0], *reinterpret_cast<const nv_bfloat162*>(&SUB));
|
||||
frag_b[1] = __hsub2(frag_b[1], *reinterpret_cast<const nv_bfloat162*>(&SUB));
|
||||
}
|
||||
|
||||
frag_b[0] = __hfma2(*reinterpret_cast<nv_bfloat162*>(&lo),
|
||||
*reinterpret_cast<const nv_bfloat162*>(&MUL),
|
||||
*reinterpret_cast<const nv_bfloat162*>(&ADD));
|
||||
frag_b[1] = __hfma2(*reinterpret_cast<nv_bfloat162*>(&hi),
|
||||
*reinterpret_cast<const nv_bfloat162*>(&MUL),
|
||||
*reinterpret_cast<const nv_bfloat162*>(&ADD));
|
||||
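The bf16 u4b8 specialization above is a concrete instance of the zero-point identity described at the top of this file: with EX = 0x4300 (= 128.0 in bf16) the bitwise step yields 128 + w for each 4-bit weight w, and a single __hsub2 against 0x4308 (= 136.0 = 128 + 8) produces the de-biased value w - 8. A small host-side check of that arithmetic is sketched below (illustrative only, not part of this commit; the helper names are made up).

// Illustrative host-side check -- not part of this commit;
// `bf16_bits_to_float` and `check_u4b8_identity` are made-up helpers.
#include <cassert>
#include <cstdint>
#include <cstring>

static float bf16_bits_to_float(uint16_t b) {
  uint32_t u = uint32_t(b) << 16;  // bf16 occupies the upper half of an fp32
  float f;
  std::memcpy(&f, &u, sizeof(f));
  return f;
}

static void check_u4b8_identity() {
  for (int w = 0; w < 16; ++w) {
    // bit_op(weight): OR the nibble into the 0x4300 (= 128.0) pattern
    float packed = bf16_bits_to_float(uint16_t(0x4300 | w));  // == 128 + w
    // flop: a single subtraction of 0x4308 (= 136.0 = 128 + 8)
    float dequanted = packed - bf16_bits_to_float(0x4308);
    assert(dequanted == float(w - 8));
  }
}
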
template <>
|
||||
__device__ inline void dequant<nv_bfloat162, vllm::kU4.id(), true>(
|
||||
int q, nv_bfloat162* frag_b) {
|
||||
dequant<nv_bfloat162, vllm::kU4B8.id(), true>(q, frag_b);
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline void dequant<nv_bfloat162, vllm::kU4.id(), false>(
|
||||
int q, nv_bfloat162* frag_b) {
|
||||
dequant<nv_bfloat162, vllm::kU4.id(), true>(q, frag_b);
|
||||
|
||||
static constexpr uint32_t SUB = 0x43004300;
|
||||
|
||||
frag_b[0] = __hsub2(frag_b[0], *reinterpret_cast<const nv_bfloat162*>(&SUB));
|
||||
frag_b[1] = __hsub2(frag_b[1], *reinterpret_cast<const nv_bfloat162*>(&SUB));
|
||||
}
|
||||
|
||||
//
|
||||
@ -140,8 +224,8 @@ __device__ inline void dequant<nv_bfloat162, vllm::kU4.id()>(
|
||||
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L125-L175
|
||||
//
|
||||
template <>
|
||||
__device__ inline void dequant<half2, vllm::kU8B128.id()>(int q,
|
||||
half2* frag_b) {
|
||||
__device__ inline void dequant<half2, vllm::kU8B128.id(), true>(int q,
|
||||
half2* frag_b) {
|
||||
static constexpr uint32_t mask_for_elt_01 = 0x5250;
|
||||
static constexpr uint32_t mask_for_elt_23 = 0x5351;
|
||||
static constexpr uint32_t start_byte_for_fp16 = 0x64646464;
|
||||
@ -149,33 +233,42 @@ __device__ inline void dequant<half2, vllm::kU8B128.id()>(int q,
|
||||
uint32_t lo = prmt<start_byte_for_fp16, mask_for_elt_01>(q);
|
||||
uint32_t hi = prmt<start_byte_for_fp16, mask_for_elt_23>(q);
|
||||
|
||||
frag_b[0] = *reinterpret_cast<half2*>(&lo);
|
||||
frag_b[1] = *reinterpret_cast<half2*>(&hi);
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline void dequant<half2, vllm::kU8B128.id(), false>(
|
||||
int q, half2* frag_b) {
|
||||
dequant<half2, vllm::kU8B128.id(), true>(q, frag_b);
|
||||
|
||||
static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480;
|
||||
|
||||
frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo),
|
||||
frag_b[0] = __hsub2(frag_b[0],
|
||||
*reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
|
||||
frag_b[1] = __hsub2(*reinterpret_cast<half2*>(&hi),
|
||||
frag_b[1] = __hsub2(frag_b[1],
|
||||
*reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline void dequant<half2, vllm::kU8.id()>(int q, half2* frag_b) {
|
||||
static constexpr uint32_t mask_for_elt_01 = 0x5250;
|
||||
static constexpr uint32_t mask_for_elt_23 = 0x5351;
|
||||
static constexpr uint32_t start_byte_for_fp16 = 0x64646464;
|
||||
__device__ inline void dequant<half2, vllm::kU8.id(), true>(int q,
|
||||
half2* frag_b) {
|
||||
dequant<half2, vllm::kU8B128.id(), true>(q, frag_b);
|
||||
}
|
||||
|
||||
uint32_t lo = prmt<start_byte_for_fp16, mask_for_elt_01>(q);
|
||||
uint32_t hi = prmt<start_byte_for_fp16, mask_for_elt_23>(q);
|
||||
template <>
|
||||
__device__ inline void dequant<half2, vllm::kU8.id(), false>(int q,
|
||||
half2* frag_b) {
|
||||
dequant<half2, vllm::kU8.id(), true>(q, frag_b);
|
||||
|
||||
static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64006400;
|
||||
|
||||
frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo),
|
||||
frag_b[0] = __hsub2(frag_b[0],
|
||||
*reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
|
||||
frag_b[1] = __hsub2(*reinterpret_cast<half2*>(&hi),
|
||||
frag_b[1] = __hsub2(frag_b[1],
|
||||
*reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline void dequant<nv_bfloat162, vllm::kU8B128.id()>(
|
||||
__device__ inline void dequant<nv_bfloat162, vllm::kU8B128.id(), false>(
|
||||
int q, nv_bfloat162* frag_b) {
|
||||
float fp32_intermediates[4];
|
||||
uint32_t* fp32_intermediates_casted =
|
||||
@ -200,7 +293,7 @@ __device__ inline void dequant<nv_bfloat162, vllm::kU8B128.id()>(
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline void dequant<nv_bfloat162, vllm::kU8.id()>(
|
||||
__device__ inline void dequant<nv_bfloat162, vllm::kU8.id(), false>(
|
||||
int q, nv_bfloat162* frag_b) {
|
||||
float fp32_intermediates[4];
|
||||
uint32_t* fp32_intermediates_casted =
|
||||
@ -225,22 +318,30 @@ __device__ inline void dequant<nv_bfloat162, vllm::kU8.id()>(
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline void dequant<half2, vllm::kFE4M3fn.id()>(int q,
|
||||
half2* frag_b) {
|
||||
__device__ inline void dequant<half2, vllm::kFE4M3fn.id(), true>(
|
||||
int q, half2* frag_b) {
|
||||
// Constants for FP8 (E4M3) and FP16 formats
|
||||
constexpr int FP8_EXPONENT = 4, FP8_MANTISSA = 3, FP16_EXPONENT = 5;
|
||||
constexpr int FP8_EXPONENT = 4, FP16_EXPONENT = 5;
|
||||
constexpr int RIGHT_SHIFT = FP16_EXPONENT - FP8_EXPONENT;
|
||||
|
||||
// Calculate MASK for extracting mantissa and exponent
|
||||
constexpr int MASK1 = 0x80000000;
|
||||
constexpr int MASK2 = MASK1 >> (FP8_EXPONENT + FP8_MANTISSA);
|
||||
constexpr int MASK3 = MASK2 & 0x7fffffff;
|
||||
constexpr int MASK = MASK3 | (MASK3 >> 16);
|
||||
// Final MASK value: 0x7F007F00
|
||||
constexpr int MASK = 0x7F007F00;
|
||||
|
||||
// Extract and shift FP8 values to FP16 format
|
||||
int Out1 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT);
|
||||
int Out2 = ((q << 8) & 0x80008000) | (((q << 8) & MASK) >> RIGHT_SHIFT);
|
||||
q <<= 8;
|
||||
int Out2 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT);
|
||||
|
||||
// Note: reverse indexing is intentional because weights are permuted
|
||||
frag_b[1] = *reinterpret_cast<const half2*>(&Out1);
|
||||
frag_b[0] = *reinterpret_cast<const half2*>(&Out2);
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline void dequant<half2, vllm::kFE4M3fn.id(), false>(
|
||||
int q, half2* frag_b) {
|
||||
dequant<half2, vllm::kFE4M3fn.id(), true>(q, frag_b);
|
||||
|
||||
// Constants for FP8 (E4M3) and FP16 formats
|
||||
constexpr int FP8_EXPONENT = 4, FP16_EXPONENT = 5;
|
||||
|
||||
// Construct and apply exponent bias
|
||||
constexpr int BIAS_OFFSET =
|
||||
@ -248,28 +349,36 @@ __device__ inline void dequant<half2, vllm::kFE4M3fn.id()>(int q,
|
||||
const half2 bias_reg = __float2half2_rn(float(1 << BIAS_OFFSET));
|
||||
|
||||
// Convert to half2 and apply bias
|
||||
// Note: reverse indexing is intentional because weights are permuted
|
||||
frag_b[1] = __hmul2(*reinterpret_cast<const half2*>(&Out1), bias_reg);
|
||||
frag_b[0] = __hmul2(*reinterpret_cast<const half2*>(&Out2), bias_reg);
|
||||
frag_b[1] = __hmul2(frag_b[1], bias_reg);
|
||||
frag_b[0] = __hmul2(frag_b[0], bias_reg);
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline void dequant<nv_bfloat162, vllm::kFE4M3fn.id()>(
|
||||
__device__ inline void dequant<nv_bfloat162, vllm::kFE4M3fn.id(), true>(
|
||||
int q, nv_bfloat162* frag_b) {
|
||||
// Constants for FP8 (E4M3) and BF16 formats
|
||||
constexpr int FP8_EXPONENT = 4, FP8_MANTISSA = 3, BF16_EXPONENT = 8;
|
||||
constexpr int FP8_EXPONENT = 4, BF16_EXPONENT = 8;
|
||||
constexpr int RIGHT_SHIFT = BF16_EXPONENT - FP8_EXPONENT;
|
||||
|
||||
// Calculate MASK for extracting mantissa and exponent
|
||||
constexpr int MASK1 = 0x80000000;
|
||||
constexpr int MASK2 = MASK1 >> (FP8_EXPONENT + FP8_MANTISSA);
|
||||
constexpr int MASK3 = MASK2 & 0x7fffffff;
|
||||
constexpr int MASK = MASK3 | (MASK3 >> 16);
|
||||
// Final MASK value: 0x7F007F00
|
||||
constexpr int MASK = 0x7F007F00;
|
||||
|
||||
// Extract and shift FP8 values to BF16 format
|
||||
int Out1 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT);
|
||||
int Out2 = ((q << 8) & 0x80008000) | (((q << 8) & MASK) >> RIGHT_SHIFT);
|
||||
q <<= 8;
|
||||
int Out2 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT);
|
||||
|
||||
// Note: reverse indexing is intentional because weights are permuted
|
||||
frag_b[1] = *reinterpret_cast<const nv_bfloat162*>(&Out1);
|
||||
frag_b[0] = *reinterpret_cast<const nv_bfloat162*>(&Out2);
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline void dequant<nv_bfloat162, vllm::kFE4M3fn.id(), false>(
|
||||
int q, nv_bfloat162* frag_b) {
|
||||
dequant<nv_bfloat162, vllm::kFE4M3fn.id(), true>(q, frag_b);
|
||||
|
||||
// Constants for FP8 (E4M3) and BF16 formats
|
||||
constexpr int FP8_EXPONENT = 4, BF16_EXPONENT = 8;
|
||||
|
||||
// Construct and apply exponent bias
|
||||
constexpr int BIAS_OFFSET =
|
||||
@ -281,9 +390,116 @@ __device__ inline void dequant<nv_bfloat162, vllm::kFE4M3fn.id()>(
|
||||
__float2bfloat162_rn(*reinterpret_cast<const float*>(&BIAS));
|
||||
|
||||
// Convert to bfloat162 and apply bias
|
||||
frag_b[1] = __hmul2(frag_b[1], bias_reg);
|
||||
frag_b[0] = __hmul2(frag_b[0], bias_reg);
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline void dequant<half2, vllm::kFE2M1f.id(), true>(int q,
|
||||
half2* frag_b) {
|
||||
// Constants for FP4 (E2M1) and FP16 formats
|
||||
constexpr int FP4_EXPONENT = 2, FP16_EXPONENT = 5;
|
||||
constexpr int RIGHT_SHIFT = FP16_EXPONENT - FP4_EXPONENT;
|
||||
constexpr int MASK = 0x70007000;
|
||||
|
||||
// Extract and shift FP4 values to FP16 format
|
||||
int Out1 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT);
|
||||
q <<= 4;
|
||||
int Out2 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT);
|
||||
|
||||
// Note: reverse indexing is intentional because weights are permuted
|
||||
frag_b[1] = __hmul2(*reinterpret_cast<const nv_bfloat162*>(&Out1), bias_reg);
|
||||
frag_b[0] = __hmul2(*reinterpret_cast<const nv_bfloat162*>(&Out2), bias_reg);
|
||||
frag_b[1] = *reinterpret_cast<const half2*>(&Out1);
|
||||
frag_b[0] = *reinterpret_cast<const half2*>(&Out2);
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline void dequant<half2, vllm::kFE2M1f.id(), false>(
|
||||
int q, half2* frag_b) {
|
||||
dequant<half2, vllm::kFE2M1f.id(), true>(q, frag_b);
|
||||
|
||||
// Constants for FP4 (E2M1) and FP16 formats
|
||||
constexpr int FP4_EXPONENT = 2, FP16_EXPONENT = 5;
|
||||
|
||||
// Construct and apply exponent bias
|
||||
constexpr int BIAS_OFFSET =
|
||||
(1 << (FP16_EXPONENT - 1)) - (1 << (FP4_EXPONENT - 1));
|
||||
const half2 bias_reg = __float2half2_rn(float(1 << BIAS_OFFSET));
|
||||
|
||||
// Convert to half2 and apply bias
|
||||
frag_b[1] = __hmul2(frag_b[1], bias_reg);
|
||||
frag_b[0] = __hmul2(frag_b[0], bias_reg);
|
||||
}
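For reference, the nibble format decoded by the two kFE2M1f specializations above is sign(1) | exponent(2, bias 1) | mantissa(1), giving the value set {0, 0.5, 1, 1.5, 2, 3, 4, 6} and their negatives (kFE2M1f is declared with NAN_NONE, so every code is finite). A plain scalar decoder is sketched below (illustrative only, not part of this commit; `e2m1_to_float` is a made-up name).

// Illustrative reference decoder for one E2M1 nibble -- not part of this
// commit; `e2m1_to_float` is a made-up helper name.
#include <cstdint>

static float e2m1_to_float(uint8_t nibble) {
  float sign = (nibble & 0x8) ? -1.0f : 1.0f;
  int exp = (nibble >> 1) & 0x3;  // 2 exponent bits, bias 1
  int man = nibble & 0x1;         // 1 mantissa bit
  if (exp == 0) return sign * 0.5f * float(man);  // subnormal: 0 or +/-0.5
  return sign * (1.0f + 0.5f * float(man)) * float(1 << (exp - 1));
}
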
|
||||
|
||||
template <>
|
||||
__device__ inline void dequant<nv_bfloat162, vllm::kFE2M1f.id(), true>(
|
||||
int q, nv_bfloat162* frag_b) {
|
||||
// Constants for FP4 (E2M1) and BF16 formats
|
||||
constexpr int FP4_EXPONENT = 2, BF16_EXPONENT = 8;
|
||||
constexpr int RIGHT_SHIFT = BF16_EXPONENT - FP4_EXPONENT;
|
||||
constexpr int MASK = 0x70007000;
|
||||
|
||||
// Extract and shift FP4 values to BF16 format
|
||||
int Out1 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT);
|
||||
q <<= 4;
|
||||
int Out2 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT);
|
||||
|
||||
// Note: reverse indexing is intentional because weights are permuted
|
||||
frag_b[1] = *reinterpret_cast<const nv_bfloat162*>(&Out1);
|
||||
frag_b[0] = *reinterpret_cast<const nv_bfloat162*>(&Out2);
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline void dequant<nv_bfloat162, vllm::kFE2M1f.id(), false>(
|
||||
int q, nv_bfloat162* frag_b) {
|
||||
dequant<nv_bfloat162, vllm::kFE2M1f.id(), true>(q, frag_b);
|
||||
|
||||
// Constants for FP4 (E2M1) and BF16 formats
|
||||
constexpr int FP4_EXPONENT = 2, BF16_EXPONENT = 8;
|
||||
|
||||
// Construct and apply exponent bias
|
||||
constexpr int BIAS_OFFSET =
|
||||
(1 << (BF16_EXPONENT - 1)) - (1 << (FP4_EXPONENT - 1));
|
||||
// Add 127 (float exponent bias) to BIAS_OFFSET and shift to float exponent
|
||||
// position
|
||||
constexpr uint32_t BIAS = (BIAS_OFFSET + 127) << 23;
|
||||
const nv_bfloat162 bias_reg =
|
||||
__float2bfloat162_rn(*reinterpret_cast<const float*>(&BIAS));
|
||||
|
||||
// Convert to bfloat162 and apply bias
|
||||
frag_b[1] = __hmul2(frag_b[1], bias_reg);
|
||||
frag_b[0] = __hmul2(frag_b[0], bias_reg);
|
||||
}
|
||||
|
||||
template <typename scalar_t2>
|
||||
__device__ inline void dequant_fp8_scales(int q, scalar_t2* frag_b);
|
||||
|
||||
template <>
|
||||
__device__ inline void dequant_fp8_scales<half2>(int q, half2* frag_b) {
|
||||
int Out1 = (q & 0xFF00FF00) >> 1;
|
||||
q <<= 8;
|
||||
int Out2 = (q & 0xFF00FF00) >> 1;
|
||||
|
||||
// Note: reverse indexing is intentional because weights are permuted
|
||||
frag_b[1] = *reinterpret_cast<const half2*>(&Out1);
|
||||
frag_b[0] = *reinterpret_cast<const half2*>(&Out2);
|
||||
};
|
||||
|
||||
template <>
|
||||
__device__ inline void dequant_fp8_scales<nv_bfloat162>(int q,
|
||||
nv_bfloat162* frag_b) {
|
||||
constexpr int FP8_EXPONENT = 4, BF16_EXPONENT = 8;
|
||||
constexpr int RIGHT_SHIFT = BF16_EXPONENT - FP8_EXPONENT;
|
||||
constexpr int MASK = 0x7F007F00;
|
||||
|
||||
// Extract and shift FP8 values to BF16 format
|
||||
int Out1 = ((q & 0x80008000) >> 1) | ((q & MASK) >> RIGHT_SHIFT);
|
||||
q <<= 8;
|
||||
int Out2 = ((q & 0x80008000) >> 1) | ((q & MASK) >> RIGHT_SHIFT);
|
||||
|
||||
// Note: reverse indexing is intentional because weights are permuted
|
||||
frag_b[1] = *reinterpret_cast<const nv_bfloat162*>(&Out1);
|
||||
frag_b[0] = *reinterpret_cast<const nv_bfloat162*>(&Out2);
|
||||
}
|
||||
|
||||
#endif
|
||||
|
@ -31,7 +31,10 @@ TEMPLATE = ("template __global__ void Marlin<"
|
||||
|
||||
# int8 with zero point case (vllm::kU8) is also supported,
# but we don't add it, to reduce wheel size.
|
||||
SCALAR_TYPES = ["vllm::kU4", "vllm::kU4B8", "vllm::kU8B128", "vllm::kFE4M3fn"]
|
||||
SCALAR_TYPES = [
|
||||
"vllm::kU4", "vllm::kU4B8", "vllm::kU8B128", "vllm::kFE4M3fn",
|
||||
"vllm::kFE2M1f"
|
||||
]
|
||||
THREAD_CONFIGS = [(128, 128, 256), (64, 256, 256), (64, 128, 128),
|
||||
(128, 64, 128)]
|
||||
|
||||
@ -40,7 +43,7 @@ THREAD_M_BLOCKS = [0.5, 1, 2, 3, 4]
|
||||
# = 0 : act order case
|
||||
# = -1 : channelwise quantization
|
||||
# > 0 : group_size=16*group_blocks
|
||||
GROUP_BLOCKS = [0, -1, 2, 4, 8]
|
||||
GROUP_BLOCKS = [0, 1, -1, 2, 4, 8]
|
||||
DTYPES = ["fp16", "bf16"]
|
||||
|
||||
|
||||
@ -73,6 +76,12 @@ def generate_new_kernels():
|
||||
# for fp8
|
||||
if scalar_type == "vllm::kFE4M3fn" and group_blocks not in [-1, 8]:
|
||||
continue
|
||||
# nvfp4 only supports group_size == 16
|
||||
if scalar_type == "vllm::kFE2M1f" and group_blocks != 1:
|
||||
continue
|
||||
# other quantization methods don't support group_size = 16
|
||||
if scalar_type != "vllm::kFE2M1f" and group_blocks == 1:
|
||||
continue
|
||||
|
||||
k_blocks = thread_configs[0] // 16
|
||||
n_blocks = thread_configs[1] // 16
|
||||
|
@ -258,6 +258,7 @@ bool is_valid_config(thread_config_t const& th_config, int thread_m_blocks,
|
||||
// BIGGROUP: cases for big group size (group_blocks in [-1, 8])
|
||||
// FZP: cases for float-zero-point (is_zp_float = true)
|
||||
// ACT: cases for act order case (group_blocks == 0)
|
||||
// FP4: cases for nvfp4(e2m1) (group_blocks == 1)
|
||||
#define COMMON_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
|
||||
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, -1, NUM_THREADS, false) \
|
||||
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 2, NUM_THREADS, false) \
|
||||
@ -314,6 +315,23 @@ bool is_valid_config(thread_config_t const& th_config, int thread_m_blocks,
|
||||
BIGGROUP_GET_IF_M234(W_TYPE, 8, 4, 128) \
|
||||
BIGGROUP_GET_IF_M234(W_TYPE, 4, 8, 128)
|
||||
|
||||
#define FP4_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
|
||||
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 1, NUM_THREADS, false) \
|
||||
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false)
|
||||
|
||||
#define FP4_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
|
||||
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \
|
||||
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \
|
||||
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false)
|
||||
|
||||
#define FP4_GET_IF(W_TYPE) \
|
||||
FP4_GET_IF_M1(W_TYPE, 8, 8, 256) \
|
||||
FP4_GET_IF_M1(W_TYPE, 8, 4, 128) \
|
||||
FP4_GET_IF_M1(W_TYPE, 4, 8, 128) \
|
||||
FP4_GET_IF_M234(W_TYPE, 16, 4, 256) \
|
||||
FP4_GET_IF_M234(W_TYPE, 8, 4, 128) \
|
||||
FP4_GET_IF_M234(W_TYPE, 4, 8, 128)
|
||||
|
||||
// We currently have 4-bit models only with group_blocks == 4
|
||||
#define FZP_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
|
||||
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 4, NUM_THREADS, true) \
|
||||
@ -366,6 +384,8 @@ MarlinFuncPtr get_marlin_kernel(const vllm::ScalarType q_type,
|
||||
COMMON_GET_IF(vllm::kU4B8)
|
||||
COMMON_GET_IF(vllm::kU8B128)
|
||||
|
||||
FP4_GET_IF(vllm::kFE2M1f)
|
||||
|
||||
BIGGROUP_GET_IF(vllm::kFE4M3fn)
|
||||
|
||||
ACT_GET_IF(vllm::kU4B8)
|
||||
@ -434,8 +454,8 @@ exec_config_t determine_exec_config(const vllm::ScalarType& q_type, int prob_m,
|
||||
|
||||
template <typename scalar_t>
|
||||
void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
|
||||
void* zp, void* g_idx, void* perm, void* a_tmp, int prob_m,
|
||||
int prob_n, int prob_k, int lda, void* workspace,
|
||||
void* s2, void* zp, void* g_idx, void* perm, void* a_tmp,
|
||||
int prob_m, int prob_n, int prob_k, int lda, void* workspace,
|
||||
vllm::ScalarType const& q_type, bool has_act_order,
|
||||
bool is_k_full, bool has_zp, int num_groups, int group_size,
|
||||
int dev, cudaStream_t stream, int thread_k_init,
|
||||
@ -446,11 +466,12 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
|
||||
q_type == vllm::kU4 || q_type == vllm::kU8,
|
||||
"q_type must be u4 or u8 when has_zp = True. Got = ", q_type.str());
|
||||
} else {
|
||||
TORCH_CHECK(q_type == vllm::kU4B8 || q_type == vllm::kU8B128 ||
|
||||
q_type == vllm::kFE4M3fn,
|
||||
"q_type must be uint4b8, uint8b128 or float8_e4m3fn when "
|
||||
"has_zp = False. Got = ",
|
||||
q_type.str());
|
||||
TORCH_CHECK(
|
||||
q_type == vllm::kU4B8 || q_type == vllm::kU8B128 ||
|
||||
q_type == vllm::kFE4M3fn || q_type == vllm::kFE2M1f,
|
||||
"q_type must be uint4b8, uint8b128, float8_e4m3fn or float4_e2m1f when "
|
||||
"has_zp = False. Got = ",
|
||||
q_type.str());
|
||||
}
|
||||
|
||||
TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m,
|
||||
@ -483,6 +504,7 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
|
||||
int4* C_ptr = (int4*)C;
|
||||
int4* C_tmp_ptr = (int4*)C_tmp;
|
||||
const int4* s_ptr = (const int4*)s;
|
||||
const uint16_t* s2_ptr = (const uint16_t*)s2;
|
||||
const int4* zp_ptr = (const int4*)zp;
|
||||
const int* g_idx_ptr = (const int*)g_idx;
|
||||
const int* perm_ptr = (const int*)perm;
|
||||
@ -601,7 +623,7 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
|
||||
// avoid ">>>" being formatted to "> > >"
|
||||
// clang-format off
|
||||
kernel<<<blocks, num_threads, max_shared_mem_new, stream>>>(
|
||||
A_ptr, B_ptr, C_ptr, C_tmp_ptr, s_ptr, zp_ptr, g_idx_ptr, num_groups,
|
||||
A_ptr, B_ptr, C_ptr, C_tmp_ptr, s_ptr, s2_ptr, zp_ptr, g_idx_ptr, num_groups,
|
||||
prob_m_split, prob_n, prob_k, lda, locks, part_use_atomic_add,
|
||||
use_fp32_reduce, max_shared_mem_new);
|
||||
// clang-format on
|
||||
@ -617,6 +639,7 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
|
||||
torch::Tensor gptq_marlin_gemm(
|
||||
torch::Tensor& a, std::optional<torch::Tensor> c_or_none,
|
||||
torch::Tensor& b_q_weight, torch::Tensor& b_scales,
|
||||
std::optional<torch::Tensor> const& global_scale_or_none,
|
||||
std::optional<torch::Tensor> const& b_zeros_or_none,
|
||||
std::optional<torch::Tensor> const& g_idx_or_none,
|
||||
std::optional<torch::Tensor> const& perm_or_none, torch::Tensor& workspace,
|
||||
@ -759,6 +782,17 @@ torch::Tensor gptq_marlin_gemm(
|
||||
}
|
||||
}
|
||||
|
||||
torch::Tensor global_scale;
|
||||
if (global_scale_or_none.has_value()) {
|
||||
global_scale = global_scale_or_none.value();
|
||||
TORCH_CHECK(b_q_type == vllm::kFE2M1f,
|
||||
"global_scale can only be used for float4_e2m1f.");
|
||||
} else {
|
||||
global_scale = torch::empty({0}, options);
|
||||
TORCH_CHECK(!(b_q_type == vllm::kFE2M1f),
|
||||
"the global_scale parameter must be passed for float4_e2m1f.");
|
||||
}
|
||||
|
||||
torch::Tensor b_zeros;
|
||||
if (b_zeros_or_none.has_value()) {
|
||||
b_zeros = b_zeros_or_none.value();
|
||||
@ -774,8 +808,9 @@ torch::Tensor gptq_marlin_gemm(
|
||||
"b_q_type must be u4 or u8 when has_zp = True. Got = ", b_q_type.str());
|
||||
} else {
|
||||
TORCH_CHECK(b_q_type == vllm::kU4B8 || b_q_type == vllm::kU8B128 ||
|
||||
b_q_type == vllm::kFE4M3fn,
|
||||
"b_q_type must be uint4b8, uint8b128 or float8_e4m3fn when "
|
||||
b_q_type == vllm::kFE4M3fn || b_q_type == vllm::kFE2M1f,
|
||||
"b_q_type must be uint4b8, uint8b128, float8_e4m3fn or "
|
||||
"float4_e2m1f when "
|
||||
"has_zp = False. Got = ",
|
||||
b_q_type.str());
|
||||
}
|
||||
@ -820,22 +855,36 @@ torch::Tensor gptq_marlin_gemm(
|
||||
|
||||
int dev = a.get_device();
|
||||
if (a.scalar_type() == at::ScalarType::Half) {
|
||||
void* scales_ptr;
|
||||
if (b_q_type == vllm::kFE2M1f) {
|
||||
scales_ptr = b_scales.data_ptr<at::Float8_e4m3fn>();
|
||||
} else {
|
||||
scales_ptr = b_scales.data_ptr<at::Half>();
|
||||
}
|
||||
|
||||
marlin::marlin_mm<half>(
|
||||
a.data_ptr<at::Half>(), b_q_weight.data_ptr(), c.data_ptr<at::Half>(),
|
||||
c_tmp.data_ptr<float>(), b_scales.data_ptr<at::Half>(),
|
||||
c_tmp.data_ptr<float>(), scales_ptr, global_scale.data_ptr<at::Half>(),
|
||||
b_zeros.data_ptr(), g_idx.data_ptr(), perm.data_ptr(),
|
||||
a_tmp.data_ptr<at::Half>(), size_m, size_n, size_k, a.stride(0),
|
||||
workspace.data_ptr(), b_q_type, has_act_order, is_k_full, has_zp,
|
||||
num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev),
|
||||
thread_k, thread_n, sms, use_atomic_add, use_fp32_reduce, is_zp_float);
|
||||
} else if (a.scalar_type() == at::ScalarType::BFloat16) {
|
||||
void* scales_ptr;
|
||||
if (b_q_type == vllm::kFE2M1f) {
|
||||
scales_ptr = b_scales.data_ptr<at::Float8_e4m3fn>();
|
||||
} else {
|
||||
scales_ptr = b_scales.data_ptr<at::BFloat16>();
|
||||
}
|
||||
|
||||
marlin::marlin_mm<nv_bfloat16>(
|
||||
a.data_ptr<at::BFloat16>(), b_q_weight.data_ptr(),
|
||||
c.data_ptr<at::BFloat16>(), c_tmp.data_ptr<float>(),
|
||||
b_scales.data_ptr<at::BFloat16>(), b_zeros.data_ptr(), g_idx.data_ptr(),
|
||||
perm.data_ptr(), a_tmp.data_ptr<at::BFloat16>(), size_m, size_n, size_k,
|
||||
a.stride(0), workspace.data_ptr(), b_q_type, has_act_order, is_k_full,
|
||||
has_zp, num_groups, group_size, dev,
|
||||
c.data_ptr<at::BFloat16>(), c_tmp.data_ptr<float>(), scales_ptr,
|
||||
global_scale.data_ptr<at::BFloat16>(), b_zeros.data_ptr(),
|
||||
g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr<at::BFloat16>(),
|
||||
size_m, size_n, size_k, a.stride(0), workspace.data_ptr(), b_q_type,
|
||||
has_act_order, is_k_full, has_zp, num_groups, group_size, dev,
|
||||
at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms,
|
||||
use_atomic_add, use_fp32_reduce, is_zp_float);
|
||||
} else {
|
||||
|
@ -7,13 +7,14 @@
#include "marlin_dtypes.cuh"
#include "core/scalar_type.hpp"

#define MARLIN_KERNEL_PARAMS \
const int4 *__restrict__ A, const int4 *__restrict__ B, \
int4 *__restrict__ C, int4 *__restrict__ C_tmp, \
const int4 *__restrict__ scales_ptr, const int4 *__restrict__ zp_ptr, \
const int *__restrict__ g_idx, int num_groups, int prob_m, int prob_n, \
int prob_k, int lda, int *locks, bool use_atomic_add, \
bool use_fp32_reduce, int max_shared_mem
#define MARLIN_KERNEL_PARAMS \
const int4 *__restrict__ A, const int4 *__restrict__ B, \
int4 *__restrict__ C, int4 *__restrict__ C_tmp, \
const int4 *__restrict__ scales_ptr, \
const uint16_t *__restrict__ scale2_ptr, \
const int4 *__restrict__ zp_ptr, const int *__restrict__ g_idx, \
int num_groups, int prob_m, int prob_n, int prob_k, int lda, int *locks, \
bool use_atomic_add, bool use_fp32_reduce, int max_shared_mem

namespace MARLIN_NAMESPACE_NAME {
template <typename scalar_t, // compute dtype, half or nv_float16

@ -292,9 +292,11 @@ __global__ void Marlin(
int4* __restrict__ C_tmp, // fp32 tmp output buffer (for reduce)
const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape
// (k/groupsize)xn
const int4* __restrict__ zp_ptr, // 4bit packed zero-points of shape
// (k/groupsize)x(n/pack_factor)
const int* __restrict__ g_idx, // int32 group indices of shape k
const uint16_t* __restrict__ scale2_ptr, // fp16 global scale (for nvfp4
// only)
const int4* __restrict__ zp_ptr, // 4bit packed zero-points of shape
// (k/groupsize)x(n/pack_factor)
const int* __restrict__ g_idx, // int32 group indices of shape k
int num_groups, // number of scale groups per output channel
int prob_m, // batch dimension m
int prob_n, // output dimension n

@ -325,6 +327,21 @@ __global__ void Marlin(

static constexpr auto w_type = vllm::ScalarType::from_id(w_type_id);
constexpr bool has_zp = w_type == vllm::kU4 || w_type == vllm::kU8;
constexpr bool is_int_type = w_type == vllm::kU4 || w_type == vllm::kU8 ||
w_type == vllm::kU4B8 || w_type == vllm::kU8B128;
// see comments of dequant.h for more details
constexpr bool dequant_skip_flop =
!is_int_type ||
has_zp && !is_zp_float && !std::is_same<scalar_t, nv_bfloat16>::value ||
has_zp && !is_zp_float && !(w_type == vllm::kU8);

scalar_t2 global_scale;

if constexpr (w_type == vllm::kFE2M1f) {
uint16_t val = scale2_ptr[0];
global_scale = Dtype::num2num2(*reinterpret_cast<scalar_t*>(&val));
}

constexpr bool has_act_order = group_blocks == 0;
constexpr int m_block_size = m_block_size_8 ? 8 : (16 * thread_m_blocks);

@ -481,7 +498,7 @@ __global__ void Marlin(
constexpr int s_sh_stride = 16 * thread_n_blocks / 8;
constexpr int s_tb_groups =
!has_act_order && group_blocks != -1 && group_blocks < thread_k_blocks
? thread_k_blocks / group_blocks
? thread_k_blocks / group_blocks / (w_type == vllm::kFE2M1f ? 2 : 1)
: 1;
constexpr int s_sh_stage = s_tb_groups * s_sh_stride;
int s_gl_rd_delta = s_gl_stride;

@ -540,7 +557,8 @@ __global__ void Marlin(
if constexpr (group_blocks == -1) {
s_gl_rd = s_sh_stride * slice_col + threadIdx.x;
} else {
s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) +
s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) /
(w_type == vllm::kFE2M1f ? 2 : 1) +
s_sh_stride * slice_col + threadIdx.x;
}
}

@ -564,10 +582,20 @@ __global__ void Marlin(
// we scale a `half2` tile in column-major layout in the former and in
// row-major in the latter case.
int s_sh_rd;
if constexpr (group_blocks != -1)
if constexpr (group_blocks != -1 && w_type == vllm::kFE2M1f) {
auto warp_id = threadIdx.x / 32;
int n_warps = thread_n_blocks / 4;
int warp_row = warp_id / n_warps;

s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
(threadIdx.x % 32) / 4;
else if constexpr (group_blocks == -1 && (m_block_size_8 || has_zp))
s_sh_rd = s_sh_rd * 2 + warp_row % 2;

} else if constexpr (group_blocks != -1)
s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
(threadIdx.x % 32) / 4;
else if constexpr (group_blocks == -1 &&
(m_block_size_8 || (has_zp && !dequant_skip_flop)))
s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
(threadIdx.x % 32) / 8;
else

@ -681,7 +709,7 @@ __global__ void Marlin(
sh_first_group_id = first_group_id;
sh_num_groups = last_group_id - first_group_id + 1;

if (sh_num_groups < act_s_max_num_groups) {
if (sh_num_groups > act_s_max_num_groups) {
sh_num_groups = act_s_max_num_groups;
}

@ -887,12 +915,19 @@ __global__ void Marlin(
cur_k += k_iter_size * (k % b_sh_wr_iters);

int k_blocks = cur_k / 16;
int cur_group_id = k_blocks / group_blocks;
int cur_group_id =
k_blocks / (group_blocks * (w_type == vllm::kFE2M1f ? 2 : 1));

int4* sh_s_stage = sh_s + s_sh_stage * pipe;

reinterpret_cast<int4*>(&frag_s[k % 2])[0] =
sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride];
if constexpr (w_type_id != vllm::kFE2M1f.id()) {
reinterpret_cast<int4*>(&frag_s[k % 2])[0] =
sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride];
} else {
reinterpret_cast<int2*>(&frag_s[k % 2])[0] =
reinterpret_cast<int2*>(
sh_s_stage)[s_sh_rd + cur_group_id * (2 * s_sh_stride)];
}
}
}
@ -1065,22 +1100,7 @@ __global__ void Marlin(
};

auto dequant_data = [&](int q, scalar_t2* frag_b_ptr) {
if constexpr (has_zp && is_zp_float || !has_zp) {
dequant<scalar_t2, w_type_id>(q, frag_b_ptr);
} else {
static_assert(has_zp && !is_zp_float);
static_assert(w_type_id == vllm::kU4.id() || w_type_id == vllm::kU8.id());
// If (has_zp && !is_zp_float),
// we use not-zp version `dequant` function
// to improve numerical accuracy.
// Since both weight and zero point are dequanted using this logic,
// the final dequanted weight would be correct.
if constexpr (w_type_id == vllm::kU4.id()) {
dequant<scalar_t2, vllm::kU4B8.id()>(q, frag_b_ptr);
} else if constexpr (w_type_id == vllm::kU8.id()) {
dequant<scalar_t2, vllm::kU8B128.id()>(q, frag_b_ptr);
}
}
dequant<scalar_t2, w_type_id, dequant_skip_flop>(q, frag_b_ptr);
};

// Execute the actual tensor core matmul of a sub-tile.

@ -1110,13 +1130,23 @@ __global__ void Marlin(
dequant_data(zp_quant_1, reinterpret_cast<scalar_t2*>(&frag_zp) + 2);
}
}
if constexpr (has_zp && is_zp_float) {
if constexpr (!dequant_skip_flop && has_zp && is_zp_float) {
if (is_new_zp) {
reinterpret_cast<int4*>(&frag_zp)[0] =
reinterpret_cast<int4*>(&frag_zpf[k2])[0];
}
}

if constexpr (w_type == vllm::kFE2M1f) {
int s_quant_0 = reinterpret_cast<int*>(frag_s[k2])[0];
int s_quant_1 = reinterpret_cast<int*>(frag_s[k2])[1];

dequant_fp8_scales<scalar_t2>(s_quant_0,
reinterpret_cast<scalar_t2*>(&frag_s[k2]));
dequant_fp8_scales<scalar_t2>(
s_quant_1, reinterpret_cast<scalar_t2*>(&frag_s[k2]) + 2);
}

// We have the m dimension as the inner loop in order to encourage overlapping
// dequantization and matmul operations.
#pragma unroll

@ -1125,7 +1155,10 @@ __global__ void Marlin(
FragB frag_b1;
int b_quant_0, b_quant_1;

if constexpr (w_type.size_bits() == 4) {
if constexpr (w_type_id == vllm::kFE2M1f.id()) {
b_quant_1 = frag_b_quant[k2][0][j];
b_quant_0 = b_quant_1 << 8;
} else if constexpr (w_type.size_bits() == 4) {
b_quant_0 = frag_b_quant[k2][0][j];
b_quant_1 = b_quant_0 >> 8;
} else {

@ -1138,6 +1171,11 @@ __global__ void Marlin(
dequant_data(b_quant_0, reinterpret_cast<scalar_t2*>(&frag_b0));
dequant_data(b_quant_1, reinterpret_cast<scalar_t2*>(&frag_b1));

if constexpr (dequant_skip_flop && has_zp && !is_zp_float) {
sub_zp<scalar_t>(frag_b0, frag_zp[j], 0);
sub_zp<scalar_t>(frag_b1, frag_zp[j], 1);
}

// Apply scale to frag_b0
if constexpr (has_act_order) {
static_assert(group_blocks != -1);

@ -1145,7 +1183,8 @@ __global__ void Marlin(
act_frag_s[k2][2][j], act_frag_s[k2][3][j], 0);
scale4<scalar_t>(frag_b1, act_frag_s[k2][0][j], act_frag_s[k2][1][j],
act_frag_s[k2][2][j], act_frag_s[k2][3][j], 1);
} else if constexpr (has_zp && !is_zp_float && group_blocks == -1) {
} else if constexpr (!dequant_skip_flop && has_zp && !is_zp_float &&
group_blocks == -1) {
int idx = (threadIdx.x / 4) % 2;
scalar_t2 s2 = Dtype::nums2num2(
reinterpret_cast<scalar_t*>(&frag_s[j / 2][j % 2 * 2 + 0])[idx],

@ -1153,7 +1192,7 @@ __global__ void Marlin(
if (is_new_zp) frag_zp[j] = __hmul2(frag_zp[j], s2);
scale_and_sub<scalar_t>(frag_b0, s2.x, frag_zp[j].x);
scale_and_sub<scalar_t>(frag_b1, s2.y, frag_zp[j].y);
} else if constexpr (has_zp && group_blocks != -1) {
} else if constexpr (!dequant_skip_flop && has_zp && group_blocks != -1) {
if (is_new_zp)
frag_zp[j] = __hmul2(frag_zp[j],
*reinterpret_cast<scalar_t2*>(&frag_s[k2][j]));

@ -1408,10 +1447,15 @@ __global__ void Marlin(
// For per-column quantization we finally apply the scale here (only for
// 4-bit)
if constexpr (!has_act_order && group_blocks == -1 &&
w_type.size_bits() == 4 && !has_zp) {
w_type.size_bits() == 4 &&
(has_zp && dequant_skip_flop || !has_zp)) {
res = __hmul2(res, s[0]);
}

if constexpr (w_type == vllm::kFE2M1f) {
res = __hmul2(res, global_scale);
}

if constexpr (m_block_size_8) {
((scalar_t*)sh_red)[idx] = res.x;
((scalar_t*)sh_red)[idx + 8 * c_sh_stride] = res.y;

@ -1488,7 +1532,9 @@ __global__ void Marlin(
if constexpr (has_zp && !is_zp_float && group_blocks == -1) {
if (i == 0) {
fetch_col_zp_to_shared();
fetch_col_scale_to_shared();
if constexpr (!dequant_skip_flop) {
fetch_col_scale_to_shared();
}
}
}
fetch_to_shared(i, i, i < slice_iters);

@ -1563,7 +1609,8 @@ __global__ void Marlin(
bool last = slice_idx == slice_count - 1;
// For per-column scales, we only fetch them here in the final step before
// write-out
if constexpr (!has_act_order && group_blocks == -1 && !has_zp) {
if constexpr (!has_act_order && group_blocks == -1 &&
(has_zp && dequant_skip_flop || !has_zp)) {
if (w_type.size_bits() == 8 || (last || use_atomic_add)) {
if (s_sh_wr_pred) {
cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]);

@ -1573,7 +1620,8 @@ __global__ void Marlin(
}

thread_block_reduce();
if constexpr (!has_act_order && group_blocks == -1 && !has_zp) {
if constexpr (!has_act_order && group_blocks == -1 &&
(has_zp && dequant_skip_flop || !has_zp)) {
if (w_type.size_bits() == 8 || (last || use_atomic_add)) {
cp_async_wait<0>();
__syncthreads();

@ -1597,7 +1645,8 @@ __global__ void Marlin(
// that converts the fp32 results to fp16 (so that we avoid possible
// overflow in fp16)
if constexpr (!has_act_order && group_blocks == -1 &&
w_type.size_bits() == 8 && !has_zp) {
w_type.size_bits() == 8 &&
(has_zp && dequant_skip_flop || !has_zp)) {
if (threadIdx.x / 32 < thread_n_blocks / 4) {
#pragma unroll
for (int i = 0; i < thread_m_blocks; i++) {
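Editor's note: the kernel changes above halve the scale strides and group indices for kFE2M1f because nvfp4 group scales are stored as 1-byte fp8 values rather than 2-byte halves, so two k-groups share one scale slot. The snippet below is a rough, stand-alone restatement of that indexing idea in plain Python (assumed values, not the CUDA code itself).

```python
# Rough illustration of the halved scale indexing for nvfp4 (kFE2M1f):
# with 1-byte fp8 group scales, each scale slot covers twice as many k-blocks.
def scale_slot(k_block: int, group_blocks: int, is_nvfp4: bool) -> int:
    per_slot = group_blocks * (2 if is_nvfp4 else 1)
    return k_block // per_slot

print(scale_slot(12, group_blocks=1, is_nvfp4=False))  # -> 12
print(scale_slot(12, group_blocks=1, is_nvfp4=True))   # -> 6
```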
@ -292,8 +292,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// gptq_marlin Optimized Quantized GEMM for GPTQ.
ops.def(
"gptq_marlin_gemm(Tensor a, Tensor? c_or_none, Tensor b_q_weight, "
"Tensor b_scales, Tensor? b_zeros_or_none, Tensor? g_idx_or_none, "
"Tensor? perm_or_none, Tensor workspace, int b_q_type, "
"Tensor b_scales, Tensor? global_scale, Tensor? b_zeros_or_none, Tensor? "
"g_idx_or_none, Tensor? perm_or_none, Tensor workspace, int b_q_type, "
"SymInt size_m, SymInt size_n, SymInt size_k, bool is_k_full, "
"bool use_atomic_add, bool use_fp32_reduce, bool is_zp_float) -> Tensor",
{stride_tag});
@ -16,6 +16,8 @@ from vllm.model_executor.layers.fused_moe import fused_moe
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
from vllm.model_executor.layers.fused_moe.moe_torch_iterative import (
fused_moe as iterative_moe)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
rand_marlin_weight_fp4_like)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
marlin_quant_fp8_torch)
from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (

@ -286,21 +288,64 @@ def test_mixtral_moe(dtype: torch.dtype, padding: bool, use_rocm_aiter: bool,
atol=mixtral_moe_tol[dtype])


def marlin_moe_generate_valid_test_cases():
import itertools
m_list = [1, 123, 666]
n_list = [128, 1024]
k_list = [256, 2048]
e_list = [4, 12]
topk_list = [2, 3]
ep_size_list = [1, 4]
dtype_list = [torch.half, torch.bfloat16]
group_size_list = [-1, 16, 32, 128]
act_order_list = [True, False]
quant_type_list = [
scalar_types.float4_e2m1f,
scalar_types.float8_e4m3fn,
scalar_types.uint4,
scalar_types.uint4b8,
scalar_types.uint8b128,
]
is_k_full_list = [True, False]

all_combinations = itertools.product(m_list, n_list, k_list, e_list,
topk_list, ep_size_list, dtype_list,
group_size_list, act_order_list,
quant_type_list, is_k_full_list)

def is_invalid(m, n, k, e, topk, ep_size, dtype, group_size, act_order,
quant_type, is_k_full):

if quant_type == scalar_types.float8_e4m3fn and \
group_size not in [-1, 128]:
return False
if quant_type == scalar_types.float4_e2m1f and group_size != 16:
return False
if quant_type != scalar_types.float4_e2m1f and group_size == 16:
return False

# Filter act_order
if act_order:
if group_size in (-1, k, n):
return False
if quant_type not in [scalar_types.uint4b8]:
return False
elif not is_k_full:
return False

return True

cases = []
for case in all_combinations:
if is_invalid(*case):
cases.append(case)
return cases


@pytest.mark.flaky(reruns=2)
@pytest.mark.parametrize("m", [1, 123, 666])
@pytest.mark.parametrize("n", [128, 1024])
@pytest.mark.parametrize("k", [256, 2048])
@pytest.mark.parametrize("e", [4, 12])
@pytest.mark.parametrize("topk", [2, 3])
@pytest.mark.parametrize("ep_size", [1, 4])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("group_size", [-1, 32, 128])
@pytest.mark.parametrize("act_order", [True, False])
@pytest.mark.parametrize("quant_type", [
scalar_types.uint4, scalar_types.uint8b128, scalar_types.uint4b8,
scalar_types.float8_e4m3fn
])
@pytest.mark.parametrize("is_k_full", [True, False])
@pytest.mark.parametrize(("m, n, k, e, topk, ep_size, dtype, group_size,"
"act_order, quant_type, is_k_full"),
marlin_moe_generate_valid_test_cases())
@pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm")
def test_fused_marlin_moe(
m: int,

@ -338,6 +383,11 @@ def test_fused_marlin_moe(
if not is_k_full:
return

if quant_type == scalar_types.float4_e2m1f and group_size != 16:
return
if quant_type != scalar_types.float4_e2m1f and group_size == 16:
return

a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 20
w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 20

@ -355,12 +405,27 @@ def test_fused_marlin_moe(
w_ref1_l = []
qweight1_l = []
scales1_l = []
global_scale1_l = []
zeros1_l = []
g_idx1_l = []
sort_indices1_l = []

for i in range(w1.shape[0]):
if has_zp:
if quant_type == scalar_types.float4_e2m1f:
w_ref1, qweight1, scales1, global_scale1 = \
rand_marlin_weight_fp4_like(w1[i], group_size)

w_ref1_l.append(w_ref1.T)
qweight1_l.append(qweight1)
scales1_l.append(scales1)
global_scale1_l.append(global_scale1)
elif quant_type == scalar_types.float8_e4m3fn:
w_ref1, qweight1, scales1 = marlin_quant_fp8_torch(
w1[i], group_size)
w_ref1_l.append(w_ref1.T)
qweight1_l.append(qweight1)
scales1_l.append(scales1)
elif has_zp:
w_ref1, qweight1, scales1, zeros1 = awq_marlin_quantize(
w1[i].transpose(1, 0), quant_type, group_size)

@ -368,7 +433,7 @@ def test_fused_marlin_moe(
qweight1_l.append(qweight1)
scales1_l.append(scales1)
zeros1_l.append(zeros1)
elif quant_type != scalar_types.float8_e4m3fn:
else:
test_perm = torch.randperm(k)
w_ref1, qweight1, scales1, g_idx1, sort_indices1, _ = \
marlin_quantize(w1[i].transpose(1, 0), quant_type,

@ -379,16 +444,11 @@ def test_fused_marlin_moe(
scales1_l.append(scales1)
g_idx1_l.append(g_idx1)
sort_indices1_l.append(sort_indices1)
else:
w_ref1, qweight1, scales1 = marlin_quant_fp8_torch(
w1[i], group_size)
w_ref1_l.append(w_ref1.T)
qweight1_l.append(qweight1)
scales1_l.append(scales1)

w_ref1 = stack_and_dev(w_ref1_l)
qweight1 = stack_and_dev(qweight1_l).contiguous()
scales1 = stack_and_dev(scales1_l)
global_scale1 = stack_and_dev(global_scale1_l) if global_scale1_l else None
g_idx1 = stack_and_dev(g_idx1_l) if g_idx1_l else None
zeros1 = stack_and_dev(zeros1_l) if zeros1_l else None
sort_indices1 = stack_and_dev(sort_indices1_l) if sort_indices1_l else None

@ -396,12 +456,27 @@ def test_fused_marlin_moe(
w_ref2_l = []
qweight2_l = []
scales2_l = []
global_scale2_l = []
zeros2_l = []
g_idx2_l = []
sort_indices2_l = []

for i in range(w2.shape[0]):
if has_zp:
if quant_type == scalar_types.float4_e2m1f:
w_ref2, qweight2, scales2, global_scale2 = \
rand_marlin_weight_fp4_like(w2[i], group_size)

w_ref2_l.append(w_ref2.T)
qweight2_l.append(qweight2)
scales2_l.append(scales2)
global_scale2_l.append(global_scale2)
elif quant_type == scalar_types.float8_e4m3fn:
w_ref2, qweight2, scales2 = marlin_quant_fp8_torch(
w2[i], group_size)
w_ref2_l.append(w_ref2.T)
qweight2_l.append(qweight2)
scales2_l.append(scales2)
elif has_zp:
w_ref2, qweight2, scales2, zeros2 = awq_marlin_quantize(
w2[i].transpose(1, 0), quant_type, group_size)

@ -409,7 +484,7 @@ def test_fused_marlin_moe(
qweight2_l.append(qweight2)
scales2_l.append(scales2)
zeros2_l.append(zeros2)
elif quant_type != scalar_types.float8_e4m3fn:
else:
test_perm = torch.randperm(n)
w_ref2, qweight2, scales2, g_idx2, sort_indices2, _ = \
marlin_quantize(w2[i].transpose(1, 0), quant_type,

@ -420,24 +495,18 @@ def test_fused_marlin_moe(
scales2_l.append(scales2)
g_idx2_l.append(g_idx2)
sort_indices2_l.append(sort_indices2)
else:
w_ref2, qweight2, scales2 = marlin_quant_fp8_torch(
w2[i], group_size)
w_ref2_l.append(w_ref2.T)
qweight2_l.append(qweight2)
scales2_l.append(scales2)

w_ref2 = stack_and_dev(w_ref2_l)
qweight2 = stack_and_dev(qweight2_l).contiguous()
scales2 = stack_and_dev(scales2_l)
global_scale2 = stack_and_dev(global_scale2_l) if global_scale2_l else None
g_idx2 = stack_and_dev(g_idx2_l) if g_idx2_l else None
zeros2 = stack_and_dev(zeros2_l) if zeros2_l else None
sort_indices2 = stack_and_dev(sort_indices2_l) if sort_indices2_l else None

score = torch.randn((m, e), device="cuda", dtype=dtype)

topk_weights, topk_ids, token_expert_indices = fused_topk(
a, score, topk, False)
topk_weights, topk_ids, _ = fused_topk(a, score, topk, False)

torch_output = torch_moe(a, w_ref1, w_ref2, score, topk, e_map)

@ -452,6 +521,8 @@ def test_fused_marlin_moe(
topk_ids,
global_num_experts=e,
expert_map=e_map,
global_scale1=global_scale1,
global_scale2=global_scale2,
g_idx1=g_idx1,
g_idx2=g_idx2,
sort_indices1=sort_indices1,
@ -20,6 +20,8 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import (
MARLIN_SUPPORTED_GROUP_SIZES, marlin_make_empty_g_idx,
marlin_make_workspace_new, marlin_permute_scales,
query_marlin_supported_quant_types)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
FP4_MARLIN_SUPPORTED_GROUP_SIZES, rand_marlin_weight_fp4_like)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
marlin_quant_fp8_torch)
from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (

@ -190,9 +192,10 @@ def test_awq_marlin_repack(k_chunk, n_chunk, quant_type, group_size,
reason="Marlin is not supported on this GPU type.")
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
@pytest.mark.parametrize("quant_type",
query_marlin_supported_quant_types(False))
@pytest.mark.parametrize("group_size", MARLIN_SUPPORTED_GROUP_SIZES)
@pytest.mark.parametrize("quant_type", query_marlin_supported_quant_types())
@pytest.mark.parametrize(
"group_size",
set(MARLIN_SUPPORTED_GROUP_SIZES + FP4_MARLIN_SUPPORTED_GROUP_SIZES))
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
@pytest.mark.parametrize("act_order", ACT_ORDER_OPTS)
@pytest.mark.parametrize("is_k_full", K_FULL_OPTS)

@ -210,6 +213,7 @@ def test_gptq_marlin_gemm(
use_fp32_reduce,
):
m_factor, n_factor, k_factor = mnk_factors
has_zp = quant_type in [scalar_types.uint4, scalar_types.uint8]

size_m = m_factor
size_k = k_chunk * k_factor

@ -220,6 +224,8 @@ def test_gptq_marlin_gemm(
return
if group_size == size_k:
return
if has_zp:
return

if size_k % group_size != 0:
return

@ -227,7 +233,15 @@ def test_gptq_marlin_gemm(
a_input = rand_data((size_m, size_k))
b_weight = rand_data((size_k, size_n))

if quant_type == scalar_types.float8_e4m3fn:
if quant_type == scalar_types.float4_e2m1f:
if group_size != 16 or act_order:
return
w_ref, marlin_q_w, marlin_s, marlin_s2 = rand_marlin_weight_fp4_like(
b_weight.T, group_size)
g_idx = None
sort_indices = None
marlin_zp = None
elif quant_type == scalar_types.float8_e4m3fn:
if group_size not in [-1, 128]:
return
if act_order:

@ -236,26 +250,39 @@ def test_gptq_marlin_gemm(
b_weight.T, group_size)
g_idx = None
sort_indices = None
marlin_zp = None
marlin_s2 = None
elif has_zp:
if group_size == 16:
return
w_ref, marlin_q_w, marlin_s, marlin_zp = awq_marlin_quantize(
b_weight, quant_type, group_size)
g_idx = None
sort_indices = None
marlin_s2 = None
else:
if group_size == 16:
return
w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, _ = marlin_quantize(
b_weight, quant_type, group_size, act_order)

marlin_zp = marlin_make_empty_g_idx(marlin_s.device)
marlin_zp = None
marlin_s2 = None

workspace = marlin_make_workspace_new(w_ref.device)

opcheck(
torch.ops._C.gptq_marlin_gemm,
(a_input, None, marlin_q_w, marlin_s, marlin_zp, g_idx, sort_indices,
workspace, quant_type.id, a_input.shape[0], b_weight.shape[1],
a_input.shape[1], is_k_full, use_atomic_add, use_fp32_reduce, False),
test_utils=DEFAULT_OPCHECK_TEST_UTILS)
opcheck(torch.ops._C.gptq_marlin_gemm,
(a_input, None, marlin_q_w, marlin_s, marlin_s2, marlin_zp, g_idx,
sort_indices, workspace, quant_type.id, a_input.shape[0],
b_weight.shape[1], a_input.shape[1], is_k_full, use_atomic_add,
use_fp32_reduce, False),
test_utils=DEFAULT_OPCHECK_TEST_UTILS)

output = ops.gptq_marlin_gemm(
a_input,
None,
marlin_q_w,
marlin_s,
marlin_s2,
marlin_zp,
g_idx,
sort_indices,

@ -339,67 +366,6 @@ def test_gptq_marlin_24_gemm(k_chunk, n_chunk, quant_type, group_size,
assert max_diff < 0.04


@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"),
reason="Marlin is not supported on this GPU type.")
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
@pytest.mark.parametrize("quant_type",
query_marlin_supported_quant_types(True))
@pytest.mark.parametrize("group_size", MARLIN_SUPPORTED_GROUP_SIZES)
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
@pytest.mark.parametrize("use_fp32_reduce", USE_FP32_REDUCE_OPTS)
def test_awq_marlin_gemm(
k_chunk,
n_chunk,
quant_type,
group_size,
mnk_factors,
use_fp32_reduce,
):
m_factor, n_factor, k_factor = mnk_factors

size_m = m_factor
size_k = k_chunk * k_factor
size_n = n_chunk * n_factor

a_input = rand_data((size_m, size_k))
b_weight = rand_data((size_k, size_n))

w_ref, marlin_q_w, marlin_s, marlin_zp = awq_marlin_quantize(
b_weight, quant_type, group_size)

g_idx = torch.empty(0, dtype=torch.int, device=marlin_q_w.device)
sort_indices = torch.empty(0, dtype=torch.int, device=marlin_q_w.device)
is_k_full = True

workspace = marlin_make_workspace_new(a_input.device)

output = ops.gptq_marlin_gemm(
a_input,
None,
marlin_q_w,
marlin_s,
marlin_zp,
g_idx,
sort_indices,
workspace,
quant_type,
a_input.shape[0],
b_weight.shape[1],
a_input.shape[1],
is_k_full=is_k_full,
use_fp32_reduce=use_fp32_reduce,
is_zp_float=False,
)
output_ref = torch.matmul(a_input, w_ref)

torch.cuda.synchronize()

max_diff = compute_max_diff(output, output_ref)

assert max_diff < 0.04


@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"),
reason="Marlin is not supported on this GPU type.")
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)

@ -452,6 +418,7 @@ def test_hqq_marlin_gemm(
None,
marlin_w_q,
marlin_s,
None,
marlin_zp,
g_idx,
g_idx_sort_indices,

@ -564,6 +531,7 @@ def test_marlin_gemm_subset_input():
None,
marlin_q_w,
marlin_s,
None,
marlin_zp,
g_idx,
sort_indices,
@ -333,6 +333,7 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
c: Optional[torch.Tensor],
b_q_weight: torch.Tensor,
b_scales: torch.Tensor,
global_scale: Optional[torch.Tensor],
b_zeros: Optional[torch.Tensor],
g_idx: Optional[torch.Tensor],
perm: Optional[torch.Tensor],

@ -866,6 +867,7 @@ def gptq_marlin_gemm(a: torch.Tensor,
c: Optional[torch.Tensor],
b_q_weight: torch.Tensor,
b_scales: torch.Tensor,
global_scale: Optional[torch.Tensor],
b_zeros: Optional[torch.Tensor],
g_idx: Optional[torch.Tensor],
perm: Optional[torch.Tensor],

@ -878,9 +880,10 @@ def gptq_marlin_gemm(a: torch.Tensor,
use_atomic_add: bool = False,
use_fp32_reduce: bool = False,
is_zp_float: bool = False) -> torch.Tensor:
return torch.ops._C.gptq_marlin_gemm(a, c, b_q_weight, b_scales, b_zeros,
g_idx, perm, workspace, b_q_type.id,
size_m, size_n, size_k, is_k_full,
return torch.ops._C.gptq_marlin_gemm(a, c, b_q_weight, b_scales,
global_scale, b_zeros, g_idx, perm,
workspace, b_q_type.id, size_m,
size_n, size_k, is_k_full,
use_atomic_add, use_fp32_reduce,
is_zp_float)

@ -1381,6 +1384,7 @@ def topk_softmax(topk_weights: torch.Tensor, topk_ids: torch.Tensor,

def moe_wna16_marlin_gemm(input: torch.Tensor, output: Optional[torch.Tensor],
b_qweight: torch.Tensor, b_scales: torch.Tensor,
global_scale: Optional[torch.Tensor],
b_qzeros: Optional[torch.Tensor],
g_idx: Optional[torch.Tensor],
perm: Optional[torch.Tensor],

@ -1395,11 +1399,11 @@ def moe_wna16_marlin_gemm(input: torch.Tensor, output: Optional[torch.Tensor],
use_fp32_reduce: bool,
is_zp_float: bool) -> torch.Tensor:
return torch.ops._moe_C.moe_wna16_marlin_gemm(
input, output, b_qweight, b_scales, b_qzeros, g_idx, perm, workspace,
sorted_token_ids, expert_ids, num_tokens_past_padded, topk_weights,
moe_block_size, top_k, mul_topk_weights, is_ep, b_q_type.id, size_m,
size_n, size_k, is_k_full, use_atomic_add, use_fp32_reduce,
is_zp_float)
input, output, b_qweight, b_scales, global_scale, b_qzeros, g_idx,
perm, workspace, sorted_token_ids, expert_ids, num_tokens_past_padded,
topk_weights, moe_block_size, top_k, mul_topk_weights, is_ep,
b_q_type.id, size_m, size_n, size_k, is_k_full, use_atomic_add,
use_fp32_reduce, is_zp_float)


if supports_moe_ops and hasattr(torch.ops._moe_C, "marlin_gemm_moe"):
@ -25,6 +25,8 @@ def fused_marlin_moe(hidden_states: torch.Tensor,
quant_type_id: int,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
global_scale1: Optional[torch.Tensor] = None,
global_scale2: Optional[torch.Tensor] = None,
g_idx1: Optional[torch.Tensor] = None,
g_idx2: Optional[torch.Tensor] = None,
sort_indices1: Optional[torch.Tensor] = None,

@ -64,11 +66,13 @@ def fused_marlin_moe(hidden_states: torch.Tensor,
quant_type = ScalarType.from_id(quant_type_id)
assert quant_type in [
scalar_types.uint4, scalar_types.uint8b128, scalar_types.uint4b8,
scalar_types.float8_e4m3fn
scalar_types.float8_e4m3fn, scalar_types.float4_e2m1f
]

int4_scalar_types = [scalar_types.uint4, scalar_types.uint4b8]
num_bits = 4 if quant_type in int4_scalar_types else 8
bit4_scalar_types = [
scalar_types.uint4, scalar_types.uint4b8, scalar_types.float4_e2m1f
]
num_bits = 4 if quant_type in bit4_scalar_types else 8

# Check constraints.
assert hidden_states.shape[0] == gating_output.shape[

@ -133,6 +137,7 @@ def fused_marlin_moe(hidden_states: torch.Tensor,
intermediate_cache1,
w1,
w1_scale,
global_scale1,
w1_zeros,
g_idx1,
sort_indices1,

@ -165,6 +170,7 @@ def fused_marlin_moe(hidden_states: torch.Tensor,
intermediate_cache3,
w2,
w2_scale,
global_scale2,
w2_zeros,
g_idx2,
sort_indices2,

@ -202,6 +208,8 @@ def fused_marlin_moe_fake(hidden_states: torch.Tensor,
topk_ids: torch.Tensor,
quant_type_id: int,
global_num_experts: int = -1,
global_scale1: Optional[torch.Tensor] = None,
global_scale2: Optional[torch.Tensor] = None,
expert_map: Optional[torch.Tensor] = None,
g_idx1: Optional[torch.Tensor] = None,
g_idx2: Optional[torch.Tensor] = None,
@ -304,8 +304,10 @@ class HQQMarlinMethod(LinearMethodBase):

marlin_out = ops.gptq_marlin_gemm(
x,
None,
layer.marlin_qweight,
scales,
None,
zeros,
layer.g_idx,
layer.g_idx_sort_indices,

@ -315,7 +317,7 @@ class HQQMarlinMethod(LinearMethodBase):
self.output_size_per_partition,
self.input_size_per_partition,
True,  # is_k_full
True,  # has_zp
False,  # use atomic add
True,  # use 32-bit reduce
True,  # use float zp
)
@ -17,6 +17,9 @@ from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
apply_fp4_marlin_linear, is_fp4_marlin_supported,
prepare_fp4_layer_for_marlin, prepare_moe_fp4_layer_for_marlin)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
is_layer_skipped)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (

@ -24,6 +27,7 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
from vllm.model_executor.parameter import (ModelWeightParameter,
PerTensorScaleParameter)
from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types

logger = init_logger(__name__)

@ -196,7 +200,7 @@ class ModelOptNvFp4Config(QuantizationConfig):

@classmethod
def get_min_capability(cls) -> int:
return 100
return 80

@classmethod
def get_config_filenames(cls) -> List[str]:

@ -278,9 +282,15 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
def __init__(self, quant_config: ModelOptNvFp4Config):
self.quant_config = quant_config
self.cutlass_nvfp4_supported = cutlass_fp4_supported()
self.use_marlin = False

if not self.cutlass_nvfp4_supported:
raise ValueError("Current platform does not support NVFP4"
" quantization. Please use Blackwell and above.")
if is_fp4_marlin_supported():
self.use_marlin = True
else:
raise ValueError("Current platform does not support NVFP4"
" quantization. Please use Blackwell and"
" above.")

def create_weights(
self,

@ -392,12 +402,29 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
layer.weight_scale_swizzled = Parameter(swizzled_weight_scale,
requires_grad=False)

if self.use_marlin:
prepare_fp4_layer_for_marlin(layer)
del layer.alpha
del layer.input_scale
del layer.weight_scale_swizzled

def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if self.use_marlin:
return apply_fp4_marlin_linear(
input=x,
weight=layer.weight,
weight_scale=layer.weight_scale,
weight_scale_2=layer.weight_scale_2,
workspace=layer.workspace,
size_n=layer.output_size_per_partition,
size_k=layer.input_size_per_partition,
bias=bias)

output_dtype = x.dtype

# for input only the contracting dimension has a constraint.

@ -434,6 +461,16 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):

def __init__(self, quant_config: ModelOptNvFp4Config):
self.quant_config = quant_config
self.cutlass_nvfp4_supported = cutlass_fp4_supported()
self.use_marlin = False

if not self.cutlass_nvfp4_supported:
if is_fp4_marlin_supported():
self.use_marlin = True
else:
raise ValueError("Current platform does not support NVFP4"
" quantization. Please use Blackwell and"
" above.")

def create_weights(self, layer: torch.nn.Module, num_experts: int,
hidden_size: int, intermediate_size_per_partition: int,

@ -442,6 +479,8 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
raise ValueError("NVFP4 quantization was selected, "
" dynamic quantization is not supported.")

layer.num_experts = num_experts
layer.params_dtype = params_dtype
layer.quant_config = self.quant_config
weight_dtype = torch.uint8
weight_scale_dtype = torch.float8_e4m3fn

@ -594,7 +633,15 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):

layer.w2_blockscale_swizzled = Parameter(w2_blockscale_swizzled,
requires_grad=False)
return

if self.use_marlin:
prepare_moe_fp4_layer_for_marlin(layer)
del layer.g1_alphas
del layer.g2_alphas
del layer.w13_input_scale_quant
del layer.w2_input_scale_quant
del layer.w13_blockscale_swizzled
del layer.w2_blockscale_swizzled

def apply(
self,

@ -614,6 +661,35 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
apply_router_weight_on_input: bool = False,
activation: str = "silu",
):
if self.use_marlin:
topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias,
)

return torch.ops.vllm.fused_marlin_moe(
x,
layer.w13_weight,
layer.w2_weight,
layer.w13_weight_scale,
layer.w2_weight_scale,
router_logits,
topk_weights,
topk_ids,
global_scale1=layer.w13_weight_scale_2,
global_scale2=layer.w2_weight_scale_2,
quant_type_id=scalar_types.float4_e2m1f.id,
global_num_experts=global_num_experts,
expert_map=expert_map)

assert activation == "silu", "Only SiLU activation is supported."
assert not apply_router_weight_on_input, (
"Router weight on input is not "
@ -33,7 +33,7 @@ USE_FP32_REDUCE_DEFAULT = True
# without runtime zero-point. We support common cases, i.e. AWQ and GPTQ.
# TODO: we may want to move this into the C++ so its closer to the actual impl
def query_marlin_supported_quant_types(
has_zp: bool,
has_zp: Optional[bool] = None,
include_fp_type: bool = True,
device_capability: Optional[int] = None,
):

@ -45,6 +45,16 @@ def query_marlin_supported_quant_types(
if device_capability < 80:
return []

# - has_zp is True: return quant_types that have zero points
# - has_zp is False: return quant_types that do not have zero points
# - has_zp is None: return both
if has_zp is None:
types0 = query_marlin_supported_quant_types(False, include_fp_type,
device_capability)
types1 = query_marlin_supported_quant_types(True, include_fp_type,
device_capability)
return types0 + types1

if has_zp:
# AWQ style, unsigned + runtime zero-point
return [scalar_types.uint4]

@ -52,7 +62,7 @@ def query_marlin_supported_quant_types(
# GPTQ style, unsigned + symmetric bias
res = [scalar_types.uint4b8, scalar_types.uint8b128]
if include_fp_type:
res += [scalar_types.float8_e4m3fn]
res += [scalar_types.float8_e4m3fn, scalar_types.float4_e2m1f]
return res


@ -394,6 +404,7 @@ def apply_gptq_marlin_linear(
None,
weight,
weight_scale,
None,
weight_zp,
g_idx,
g_idx_sort_indices,

@ -439,6 +450,7 @@ def apply_awq_marlin_linear(
None,
weight,
weight_scale,
None,
weight_zp,
g_idx,
g_idx_sort_indices,
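Editor's note: with `has_zp` now defaulting to `None`, callers can request the full supported set in one call. The short usage sketch below is illustrative; the exact contents of the returned lists depend on device capability.

```python
# Assumed illustration of the new default: has_zp=None returns the union of
# the zero-point (AWQ-style) and symmetric (GPTQ/FP) type lists.
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
    query_marlin_supported_quant_types)

all_types = query_marlin_supported_quant_types()       # both groups
sym_types = query_marlin_supported_quant_types(False)  # symmetric only
# On capability >= 8.0 the symmetric list now also contains float4_e2m1f.
print(sym_types)
```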
@ -0,0 +1,277 @@
# SPDX-License-Identifier: Apache-2.0

from typing import Optional

import torch

import vllm._custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
    USE_FP32_REDUCE_DEFAULT, marlin_make_workspace_new, marlin_permute_scales,
    should_use_atomic_add_reduce)
from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types

FP4_MARLIN_SUPPORTED_GROUP_SIZES = [16]

logger = init_logger(__name__)


def is_fp4_marlin_supported():
    return current_platform.has_device_capability(80)


def fp4_marlin_process_scales(marlin_scales):
    assert (marlin_scales >= 0).all()

    # convert to half first; we will convert to fp8 later
    marlin_scales = marlin_scales.to(torch.half)

    # 8 is the number of scale values handled by one thread
    marlin_scales = marlin_scales.view(marlin_scales.size(0) // 2, 2, -1, 8)
    marlin_scales = marlin_scales.permute(0, 2, 1, 3).reshape(
        marlin_scales.size(0) * 2, -1)

    # fit the layout of fp8 dequantization
    marlin_scales = marlin_scales.view(-1, 4)[:, [0, 2, 1, 3]].view(
        marlin_scales.size(0), -1)

    # We assume that weight_scale (FP8-S1E4M3) is always greater
    # than or equal to 0. So we can convert
    # (weight_scale * (2 ** 7)) to a special FP8-S0E5M3 format.
    # After multiplying by 2 ** 7, the top bit of FP8-S0E5M3 would always be 1
    # when weight_scale > 0. This allows us to have an exponent bias
    # closer to zero after dequantization.

    marlin_scales = (marlin_scales * (2**7)).view(torch.int16) << 1
    marlin_scales = marlin_scales.view(torch.float8_e4m3fn)
    marlin_scales = marlin_scales[:, 1::2].contiguous()

    return marlin_scales


def fp4_marlin_process_global_scale(global_scale):
    assert global_scale.dtype in [torch.half, torch.bfloat16]
    fp4_exponent = 2
    if global_scale.dtype == torch.half:
        target_exponent = 5
    elif global_scale.dtype == torch.bfloat16:
        target_exponent = 8
    # exponent_bias_fp16 = 2 ** 4 - 2 ** 1 = 14
    # exponent_bias_bf16 = 2 ** 7 - 2 ** 1 = 126
    exponent_bias = 2**(target_exponent - 1) - 2**(fp4_exponent - 1)
    return global_scale * (2.0**(exponent_bias - 7))


def apply_fp4_marlin_linear(
        input: torch.Tensor,
        weight: torch.Tensor,
        weight_scale: torch.Tensor,
        weight_scale_2: torch.Tensor,
        workspace: torch.Tensor,
        size_n: int,
        size_k: int,
        bias: Optional[torch.Tensor] = None,
        use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT) -> torch.Tensor:
    # For GPUs that lack FP4 hardware support, we can leverage the
    # Marlin kernel for fast weight-only FP4 quantization

    reshaped_x = input.reshape(-1, input.shape[-1])
    out_shape = input.shape[:-1] + (size_n, )

    use_atomic_add = should_use_atomic_add_reduce(m=reshaped_x.size(0),
                                                  n=size_n,
                                                  k=size_k,
                                                  device=input.device,
                                                  dtype=input.dtype)

    output = ops.gptq_marlin_gemm(a=reshaped_x,
                                  c=None,
                                  b_q_weight=weight,
                                  b_scales=weight_scale,
                                  global_scale=weight_scale_2,
                                  b_zeros=None,
                                  g_idx=None,
                                  perm=None,
                                  workspace=workspace,
                                  b_q_type=scalar_types.float4_e2m1f,
                                  size_m=reshaped_x.size(0),
                                  size_n=size_n,
                                  size_k=size_k,
                                  use_atomic_add=use_atomic_add,
                                  use_fp32_reduce=use_fp32_reduce)

    if bias is not None:
        output.add_(bias)  # In-place add

    return output.reshape(out_shape)


def prepare_fp4_layer_for_marlin(layer: torch.nn.Module) -> None:
    logger.warning_once(
        "Your GPU does not have native support for FP4 computation but "
        "FP4 quantization is being used. Weight-only FP4 compression will "
        "be used leveraging the Marlin kernel. This may degrade "
        "performance for compute-heavy workloads.")

    part_size_n = layer.output_size_per_partition
    part_size_k = layer.input_size_per_partition
    param_dtype = layer.params_dtype

    assert layer.weight.shape == (part_size_n, part_size_k // 2)

    device = layer.weight.device

    # WORKSPACE
    layer.workspace = marlin_make_workspace_new(device)

    # WEIGHT
    # Repack weights to marlin format
    perm = torch.empty(0, dtype=torch.int, device=device)
    qweight = layer.weight.view(torch.int32).T.contiguous()

    marlin_qweight = ops.gptq_marlin_repack(b_q_weight=qweight,
                                            perm=perm,
                                            size_k=part_size_k,
                                            size_n=part_size_n,
                                            num_bits=4)
    layer.weight = torch.nn.Parameter(marlin_qweight, requires_grad=False)

    # WEIGHT SCALES
    # Permute scales
    weight_scale = layer.weight_scale.T.to(param_dtype)
    weight_scale = marlin_permute_scales(s=weight_scale,
                                         size_k=part_size_k,
                                         size_n=part_size_n,
                                         group_size=16)
    weight_scale = fp4_marlin_process_scales(weight_scale)
    layer.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False)

    weight_scale_2 = layer.weight_scale_2.to(param_dtype)
    weight_scale_2 = fp4_marlin_process_global_scale(weight_scale_2)
    layer.weight_scale_2 = torch.nn.Parameter(weight_scale_2,
                                              requires_grad=False)

    return


def prepare_moe_fp4_layer_for_marlin(layer: torch.nn.Module) -> None:
    logger.warning_once(
        "Your GPU does not have native support for FP4 computation but "
        "FP4 quantization is being used. Weight-only FP4 compression will "
        "be used leveraging the Marlin kernel. This may degrade "
        "performance for compute-heavy workloads.")

    e = layer.num_experts
    k = layer.hidden_size
    n = layer.intermediate_size_per_partition

    # WORKSPACE
    device = layer.w13_weight.device
    param_dtype = layer.params_dtype
    layer.workspace = marlin_make_workspace_new(device, 4)
    perm = torch.empty(0, dtype=torch.int, device=device)

    # WEIGHT
    # Repack weights to marlin format
    for name in ["w13_weight", "w2_weight"]:
        weight = getattr(layer, name)
        tensor_list = []
        if "w13" in name:
            size_n, size_k = n * 2, k
        else:
            size_n, size_k = k, n

        assert weight.shape == (e, size_n, size_k // 2)

        for i in range(e):
            qweight = weight[i].view(torch.int32).T.contiguous()

            marlin_qweight = ops.gptq_marlin_repack(b_q_weight=qweight,
                                                    perm=perm,
                                                    size_k=size_k,
                                                    size_n=size_n,
                                                    num_bits=4)
            tensor_list.append(marlin_qweight)

        weight = torch.cat([x.unsqueeze(0) for x in tensor_list], 0)
        weight = torch.nn.Parameter(weight, requires_grad=False)

        setattr(layer, name, weight)

    # WEIGHT SCALES
    # Permute scales
    for name in ["w13", "w2"]:
        scales = getattr(layer, name + "_weight_scale").to(param_dtype)
        global_scale = getattr(layer, name + "_weight_scale_2").to(param_dtype)

        tensor_list = []
        if "w13" in name:
            size_n, size_k = n * 2, k
        else:
            size_n, size_k = k, n

        for i in range(e):
            marlin_scales = marlin_permute_scales(s=scales[i].T,
                                                  size_k=size_k,
                                                  size_n=size_n,
                                                  group_size=16)
            marlin_scales = fp4_marlin_process_scales(marlin_scales)
            tensor_list.append(marlin_scales)

        scales = torch.cat([x.unsqueeze(0) for x in tensor_list], 0)
        scales = torch.nn.Parameter(scales, requires_grad=False)
        setattr(layer, name + "_weight_scale", scales)

        global_scale = fp4_marlin_process_global_scale(global_scale)
        global_scale = torch.nn.Parameter(global_scale, requires_grad=False)
        setattr(layer, name + "_weight_scale_2", global_scale)


def rand_marlin_weight_fp4_like(weight, group_size):
    assert group_size > 0
    size_n, size_k = weight.shape
    device = weight.device

    scales = weight.view(size_n, -1, group_size).abs().max(-1)[0] / 6
    global_scale = scales.max() / 448
    scales = (scales / global_scale).to(torch.float8_e4m3fn)

    fp4_weight = torch.randint(0,
                               256, (size_n, size_k // 2),
                               dtype=torch.uint8,
                               device=weight.device)
    fp4_weight_part_1 = ((fp4_weight & 0b10000000) |
                         ((fp4_weight & 0b01110000) >> 2))
    fp4_weight_part_1 = fp4_weight_part_1.view(torch.float8_e4m3fn)
    fp4_weight_part_1 = fp4_weight_part_1.to(weight.dtype) * (2**6)

    fp4_weight2 = fp4_weight << 4
    fp4_weight_part_2 = ((fp4_weight2 & 0b10000000) |
                         ((fp4_weight2 & 0b01110000) >> 2))
    fp4_weight_part_2 = fp4_weight_part_2.view(torch.float8_e4m3fn)
    fp4_weight_part_2 = fp4_weight_part_2.to(weight.dtype) * (2**6)

    weight_ref = torch.cat(
        [fp4_weight_part_2.unsqueeze(2),
         fp4_weight_part_1.unsqueeze(2)], 2).view(size_n, size_k)
    weight_ref = weight_ref * global_scale.to(weight.dtype) * \
        scales.repeat_interleave(group_size, 1).to(weight.dtype)

    marlin_qweight = ops.gptq_marlin_repack(
        b_q_weight=fp4_weight.view(torch.int32).T.contiguous(),
        perm=torch.empty(0, dtype=torch.int, device=device),
        size_k=size_k,
        size_n=size_n,
        num_bits=4,
    )

    marlin_scales = marlin_permute_scales(s=scales.T.to(weight.dtype),
                                          size_k=size_k,
                                          size_n=size_n,
                                          group_size=group_size)
    marlin_scales = fp4_marlin_process_scales(marlin_scales)

    global_scale = fp4_marlin_process_global_scale(global_scale)

    return weight_ref.T, marlin_qweight, marlin_scales, global_scale

@ -19,6 +19,20 @@ def is_fp8_marlin_supported():
return current_platform.has_device_capability(80)


def fp8_fused_exponent_bias_into_scales(scales):
fp8_exponent = 4
if scales.dtype == torch.half:
target_exponent = 5
elif scales.dtype == torch.bfloat16:
target_exponent = 8
# exponent_bias_fp16 = 2 ** 4 - 2 ** 3 = 8
# exponent_bias_bf16 = 2 ** 7 - 2 ** 3 = 120
exponent_bias = 2**(target_exponent - 1) - 2**(fp8_exponent - 1)
s = torch.ones_like(scales) * 2
s = s**exponent_bias
return scales * s


def apply_fp8_marlin_linear(
input: torch.Tensor,
weight: torch.Tensor,

@ -44,6 +58,7 @@ def apply_fp8_marlin_linear(
c=None,
b_q_weight=weight,
b_scales=weight_scale,
global_scale=None,
b_zeros=None,
g_idx=None,
perm=None,

@ -132,8 +147,10 @@ def prepare_fp8_layer_for_marlin(layer: torch.nn.Module,
# block-wise quantization -> group-wise quantization
# (size_k // block_size[1], ceil(size_n / block_size[0]))
# =>(repeat)=> (size_k // block_size[1], size_n)
if not size_k_first:
scales = scales.T.contiguous()
block_n = layer.weight_block_size[0]
scales = scales.T.repeat_interleave(block_n, 1)
scales = scales.repeat_interleave(block_n, 1)
# size_n may not be divisible by block_size[0]
scales = scales[:, :part_size_n]

@ -141,6 +158,7 @@ def prepare_fp8_layer_for_marlin(layer: torch.nn.Module,
size_k=part_size_k,
size_n=part_size_n,
group_size=group_size)
marlin_scales = fp8_fused_exponent_bias_into_scales(marlin_scales)
layer.weight_scale = torch.nn.Parameter(marlin_scales, requires_grad=False)


@ -239,8 +257,10 @@ def prepare_moe_fp8_layer_for_marlin(layer: torch.nn.Module,
# block-wise quantization -> group-wise quantization
# (e, size_k // block_size[1], ceil(size_n / block_size[0]))
# =>(repeat)=> (e, size_k // block_size[1], size_n)
if not size_k_first:
scales = scales.permute(0, 2, 1)
block_n = layer.weight_block_size[0]
scales = scales.permute(0, 2, 1).repeat_interleave(block_n, 2)
scales = scales.repeat_interleave(block_n, 2)
# size_n may not be divisible by block_size[0]
scales = scales[..., :size_n].contiguous()

@ -302,4 +322,6 @@ def marlin_quant_fp8_torch(weight, group_size):
size_n=size_n,
group_size=group_size)

marlin_scales = fp8_fused_exponent_bias_into_scales(marlin_scales)

return weight_ref.T, marlin_qweight, marlin_scales
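Editor's note: both helper files above fold the exponent-bias difference between the stored low-bit format (e2m1 or e4m3) and the compute dtype (fp16 or bf16) into the scales, so the kernel's cheap bit-reinterpretation dequant yields correctly scaled values. The stand-alone sketch below (plain Python, no vLLM imports) only re-derives the constants quoted in the code comments.

```python
# Hedged arithmetic check of the exponent-bias folding used above.
def folded_bias(target_exponent_bits: int, src_exponent_bits: int) -> int:
    # bias(e) = 2**(e-1) - 1, so bias(target) - bias(src)
    # simplifies to 2**(target-1) - 2**(src-1).
    return 2**(target_exponent_bits - 1) - 2**(src_exponent_bits - 1)

# fp4 (e2m1) scales promoted to fp16 (e5) / bf16 (e8):
assert folded_bias(5, 2) == 14   # exponent_bias_fp16 = 14
assert folded_bias(8, 2) == 126  # exponent_bias_bf16 = 126

# fp8 (e4m3) scales promoted to fp16 / bf16:
assert folded_bias(5, 4) == 8    # exponent_bias_fp16 = 8
assert folded_bias(8, 4) == 120  # exponent_bias_bf16 = 120

# The fp4 path additionally divides by 2**7 because the block scales were
# pre-multiplied by 2**7 in fp4_marlin_process_scales.
print(2.0**(14 - 7))  # factor applied to the fp16 global scale -> 128.0
```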