mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 14:53:52 +08:00
[Perf] Cuda Kernel for Per Token Group Quant (#21083)
Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
@ -245,6 +245,7 @@ set(VLLM_EXT_SRC
|
||||
"csrc/quantization/gptq/q_gemm.cu"
|
||||
"csrc/quantization/compressed_tensors/int8_quant_kernels.cu"
|
||||
"csrc/quantization/fp8/common.cu"
|
||||
"csrc/quantization/fp8/per_token_group_quant.cu"
|
||||
"csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu"
|
||||
"csrc/quantization/gguf/gguf_kernel.cu"
|
||||
"csrc/quantization/activation_kernels.cu"
|
||||
|
@ -297,6 +297,11 @@ void dynamic_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
|
||||
torch::Tensor& scales,
|
||||
std::optional<torch::Tensor> const& azp);
|
||||
|
||||
void per_token_group_quant_fp8(const torch::Tensor& input,
|
||||
torch::Tensor& output_q, torch::Tensor& output_s,
|
||||
int64_t group_size, double eps, double fp8_min,
|
||||
double fp8_max, bool scale_ue8m0);
|
||||
|
||||
torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight,
|
||||
torch::Tensor b_gptq_qzeros,
|
||||
torch::Tensor b_gptq_scales, torch::Tensor b_g_idx,
|
||||
|
213
csrc/quantization/fp8/per_token_group_quant.cu
Normal file
213
csrc/quantization/fp8/per_token_group_quant.cu
Normal file
@ -0,0 +1,213 @@
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/util/Float8_e4m3fn.h>
|
||||
|
||||
#include <cmath>
|
||||
|
||||
#include <cuda_fp16.h>
|
||||
#include <cuda_bf16.h>
|
||||
|
||||
#include <torch/all.h>
|
||||
|
||||
#include "../vectorization.cuh"
|
||||
#include "../vectorization_utils.cuh"
|
||||
#include "../../dispatch_utils.h"
|
||||
|
||||
__device__ __forceinline__ float GroupReduceMax(float val, const int tid) {
|
||||
unsigned mask = 0xffff;
|
||||
|
||||
val = fmaxf(val, __shfl_xor_sync(mask, val, 8));
|
||||
val = fmaxf(val, __shfl_xor_sync(mask, val, 4));
|
||||
val = fmaxf(val, __shfl_xor_sync(mask, val, 2));
|
||||
val = fmaxf(val, __shfl_xor_sync(mask, val, 1));
|
||||
return val;
|
||||
}
|
||||
|
||||
template <typename T, typename DST_DTYPE, bool IS_COLUMN_MAJOR = false,
|
||||
bool SCALE_UE8M0 = false, typename scale_packed_t = float>
|
||||
__global__ void per_token_group_quant_8bit_kernel(
|
||||
const T* __restrict__ input, void* __restrict__ output_q,
|
||||
scale_packed_t* __restrict__ output_s, const int group_size,
|
||||
const int num_groups, const int groups_per_block, const float eps,
|
||||
const float min_8bit, const float max_8bit, const int scale_num_rows = 0,
|
||||
const int scale_stride = 0) {
|
||||
const int threads_per_group = 16;
|
||||
const int64_t local_group_id = threadIdx.x / threads_per_group;
|
||||
const int lane_id = threadIdx.x % threads_per_group;
|
||||
|
||||
const int64_t block_group_id = blockIdx.x * groups_per_block;
|
||||
const int64_t global_group_id = block_group_id + local_group_id;
|
||||
const int64_t block_group_offset = global_group_id * group_size;
|
||||
|
||||
float local_absmax = eps;
|
||||
|
||||
using scale_element_t = float;
|
||||
static_assert(sizeof(scale_packed_t) % sizeof(scale_element_t) == 0);
|
||||
|
||||
const T* group_input = input + block_group_offset;
|
||||
DST_DTYPE* group_output =
|
||||
static_cast<DST_DTYPE*>(output_q) + block_group_offset;
|
||||
scale_element_t* scale_output;
|
||||
|
||||
if constexpr (IS_COLUMN_MAJOR) {
|
||||
const int num_elems_per_pack =
|
||||
static_cast<int>(sizeof(scale_packed_t) / sizeof(scale_element_t));
|
||||
const int scale_num_rows_element = scale_num_rows * num_elems_per_pack;
|
||||
const int row_idx = global_group_id / scale_num_rows_element;
|
||||
const int col_idx_raw = global_group_id % scale_num_rows_element;
|
||||
const int col_idx = col_idx_raw / num_elems_per_pack;
|
||||
const int pack_idx = col_idx_raw % num_elems_per_pack;
|
||||
scale_output = reinterpret_cast<scale_element_t*>(output_s) +
|
||||
(col_idx * scale_stride * num_elems_per_pack +
|
||||
row_idx * num_elems_per_pack + pack_idx);
|
||||
} else {
|
||||
scale_output = output_s + global_group_id;
|
||||
}
|
||||
|
||||
// shared memory to cache each group's data to avoid double DRAM reads.
|
||||
extern __shared__ __align__(16) char smem_raw[];
|
||||
T* smem = reinterpret_cast<T*>(smem_raw);
|
||||
T* smem_group = smem + local_group_id * group_size;
|
||||
|
||||
constexpr int vec_size = 16 / sizeof(T);
|
||||
using vec_t = vllm::vec_n_t<T, vec_size>;
|
||||
|
||||
// copy global -> shared & compute absmax
|
||||
auto scalar_op_cache = [&] __device__(T & dst, const T& src) {
|
||||
float abs_v = fabsf(static_cast<float>(src));
|
||||
local_absmax = fmaxf(local_absmax, abs_v);
|
||||
dst = src;
|
||||
};
|
||||
|
||||
vllm::vectorize_with_alignment<vec_size>(
|
||||
group_input, // in
|
||||
smem_group, // out (shared)
|
||||
group_size, // elements per group
|
||||
lane_id, // thread id
|
||||
threads_per_group, // stride in group
|
||||
scalar_op_cache); // scalar handler
|
||||
|
||||
local_absmax = GroupReduceMax(local_absmax, lane_id);
|
||||
|
||||
float y_s = local_absmax / max_8bit;
|
||||
if constexpr (SCALE_UE8M0) {
|
||||
y_s = exp2f(ceilf(log2f(fmaxf(fabsf(y_s), 1e-10f))));
|
||||
}
|
||||
|
||||
scale_element_t y_s_quant = y_s;
|
||||
|
||||
if (lane_id == 0) {
|
||||
*scale_output = y_s_quant;
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// quantize shared -> global 8-bit
|
||||
auto scalar_op_quant = [&] __device__(DST_DTYPE & dst, const T& src) {
|
||||
float q = fminf(fmaxf(static_cast<float>(src) / y_s, min_8bit), max_8bit);
|
||||
dst = DST_DTYPE(q);
|
||||
};
|
||||
|
||||
vllm::vectorize_with_alignment<vec_size>(
|
||||
smem_group, // in (shared)
|
||||
group_output, // out (global quant tensor)
|
||||
group_size, // elements
|
||||
lane_id, // tid
|
||||
threads_per_group, // stride
|
||||
scalar_op_quant); // scalar handler
|
||||
}
|
||||
|
||||
void per_token_group_quant_8bit(const torch::Tensor& input,
|
||||
torch::Tensor& output_q,
|
||||
torch::Tensor& output_s, int64_t group_size,
|
||||
double eps, double min_8bit, double max_8bit,
|
||||
bool scale_ue8m0 = false) {
|
||||
TORCH_CHECK(input.is_contiguous());
|
||||
TORCH_CHECK(output_q.is_contiguous());
|
||||
|
||||
const int num_groups = input.numel() / group_size;
|
||||
|
||||
TORCH_CHECK(input.numel() % group_size == 0);
|
||||
TORCH_CHECK(output_s.dim() == 2);
|
||||
|
||||
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
constexpr int THREADS_PER_GROUP = 16;
|
||||
|
||||
int groups_per_block = 1;
|
||||
|
||||
if (num_groups % 16 == 0) {
|
||||
groups_per_block = 16;
|
||||
} else if (num_groups % 8 == 0) {
|
||||
groups_per_block = 8;
|
||||
} else if (num_groups % 4 == 0) {
|
||||
groups_per_block = 4;
|
||||
} else if (num_groups % 2 == 0) {
|
||||
groups_per_block = 2;
|
||||
}
|
||||
|
||||
auto dst_type = output_q.scalar_type();
|
||||
const int num_blocks = num_groups / groups_per_block;
|
||||
const int num_threads = groups_per_block * THREADS_PER_GROUP;
|
||||
|
||||
const bool is_column_major = output_s.stride(0) < output_s.stride(1);
|
||||
const int scale_num_rows = output_s.size(1);
|
||||
const int scale_stride = output_s.stride(1);
|
||||
|
||||
#define LAUNCH_KERNEL(T, DST_DTYPE) \
|
||||
do { \
|
||||
dim3 grid(num_blocks); \
|
||||
dim3 block(num_threads); \
|
||||
size_t smem_bytes = \
|
||||
static_cast<size_t>(groups_per_block) * group_size * sizeof(T); \
|
||||
if (is_column_major) { \
|
||||
if (scale_ue8m0) { \
|
||||
per_token_group_quant_8bit_kernel<T, DST_DTYPE, true, true> \
|
||||
<<<grid, block, smem_bytes, stream>>>( \
|
||||
static_cast<T*>(input.data_ptr()), output_q.data_ptr(), \
|
||||
static_cast<float*>(output_s.data_ptr()), group_size, \
|
||||
num_groups, groups_per_block, (float)eps, (float)min_8bit, \
|
||||
(float)max_8bit, scale_num_rows, scale_stride); \
|
||||
} else { \
|
||||
per_token_group_quant_8bit_kernel<T, DST_DTYPE, true, false> \
|
||||
<<<grid, block, smem_bytes, stream>>>( \
|
||||
static_cast<T*>(input.data_ptr()), output_q.data_ptr(), \
|
||||
static_cast<float*>(output_s.data_ptr()), group_size, \
|
||||
num_groups, groups_per_block, (float)eps, (float)min_8bit, \
|
||||
(float)max_8bit, scale_num_rows, scale_stride); \
|
||||
} \
|
||||
} else { \
|
||||
if (scale_ue8m0) { \
|
||||
per_token_group_quant_8bit_kernel<T, DST_DTYPE, false, true> \
|
||||
<<<grid, block, smem_bytes, stream>>>( \
|
||||
static_cast<T*>(input.data_ptr()), output_q.data_ptr(), \
|
||||
static_cast<float*>(output_s.data_ptr()), group_size, \
|
||||
num_groups, groups_per_block, (float)eps, (float)min_8bit, \
|
||||
(float)max_8bit); \
|
||||
} else { \
|
||||
per_token_group_quant_8bit_kernel<T, DST_DTYPE, false, false> \
|
||||
<<<grid, block, smem_bytes, stream>>>( \
|
||||
static_cast<T*>(input.data_ptr()), output_q.data_ptr(), \
|
||||
static_cast<float*>(output_s.data_ptr()), group_size, \
|
||||
num_groups, groups_per_block, (float)eps, (float)min_8bit, \
|
||||
(float)max_8bit); \
|
||||
} \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
VLLM_DISPATCH_FLOATING_TYPES(
|
||||
input.scalar_type(), "per_token_group_quant_8bit", ([&] {
|
||||
if (dst_type == at::ScalarType::Float8_e4m3fn) {
|
||||
LAUNCH_KERNEL(scalar_t, c10::Float8_e4m3fn);
|
||||
}
|
||||
}));
|
||||
|
||||
#undef LAUNCH_KERNEL
|
||||
}
|
||||
|
||||
void per_token_group_quant_fp8(const torch::Tensor& input,
|
||||
torch::Tensor& output_q, torch::Tensor& output_s,
|
||||
int64_t group_size, double eps, double fp8_min,
|
||||
double fp8_max, bool scale_ue8m0) {
|
||||
per_token_group_quant_8bit(input, output_q, output_s, group_size, eps,
|
||||
fp8_min, fp8_max, scale_ue8m0);
|
||||
}
|
@ -601,6 +601,15 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
ops.impl("dynamic_scaled_int8_quant", torch::kCUDA,
|
||||
&dynamic_scaled_int8_quant);
|
||||
|
||||
// Compute per-token-group FP8 quantized tensor and scaling factor.
|
||||
ops.def(
|
||||
"per_token_group_fp8_quant(Tensor input, Tensor! output_q, Tensor! "
|
||||
"output_s, "
|
||||
"int group_size, float eps, float fp8_min, float fp8_max, bool "
|
||||
"scale_ue8m0) -> ()");
|
||||
ops.impl("per_token_group_fp8_quant", torch::kCUDA,
|
||||
&per_token_group_quant_fp8);
|
||||
|
||||
// Mamba selective scan kernel
|
||||
ops.def(
|
||||
"selective_scan_fwd(Tensor! u, Tensor! delta,"
|
||||
|
44
tests/kernels/quantization/test_per_token_group_quant.py
Normal file
44
tests/kernels/quantization/test_per_token_group_quant.py
Normal file
@ -0,0 +1,44 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.quantization.utils import fp8_utils
|
||||
|
||||
|
||||
@pytest.mark.parametrize("shape", [(32, 128), (64, 256), (16, 512)])
|
||||
@pytest.mark.parametrize("column_major", [False, True])
|
||||
@pytest.mark.parametrize("scale_ue8m0", [False, True])
|
||||
@pytest.mark.parametrize("group_size", [64, 128])
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
|
||||
def test_per_token_group_quant_fp8(shape, column_major: bool,
|
||||
scale_ue8m0: bool, group_size: int):
|
||||
device = "cuda"
|
||||
|
||||
torch.manual_seed(42)
|
||||
num_tokens, hidden_dim = shape
|
||||
|
||||
x = (torch.randn(
|
||||
(num_tokens, hidden_dim), device=device, dtype=torch.bfloat16) * 8)
|
||||
|
||||
# cuda path
|
||||
out_q, scale = fp8_utils.per_token_group_quant_fp8(
|
||||
x,
|
||||
group_size,
|
||||
column_major_scales=column_major,
|
||||
use_ue8m0=scale_ue8m0,
|
||||
)
|
||||
|
||||
# triton ref
|
||||
with patch("vllm.platforms.current_platform.is_cuda", return_value=False):
|
||||
ref_q, ref_s = fp8_utils.per_token_group_quant_fp8(
|
||||
x,
|
||||
group_size,
|
||||
column_major_scales=column_major,
|
||||
use_ue8m0=scale_ue8m0,
|
||||
)
|
||||
|
||||
assert torch.allclose(out_q.float(), ref_q.float(), atol=0.15, rtol=0.15)
|
||||
assert torch.allclose(scale, ref_s, atol=0.01, rtol=0.01)
|
@ -366,6 +366,7 @@ def per_token_group_quant_fp8(
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
column_major_scales: bool = False,
|
||||
out_q: Optional[torch.Tensor] = None,
|
||||
use_ue8m0: bool = is_blackwell_deep_gemm_used(),
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Function to perform per-token-group quantization on an input tensor `x`.
|
||||
It converts the tensor values into signed float8 values and returns the
|
||||
@ -397,8 +398,7 @@ def per_token_group_quant_fp8(
|
||||
if x_q is None:
|
||||
x_q = torch.empty_like(x, device=x.device, dtype=dtype)
|
||||
|
||||
M = x.numel() // group_size
|
||||
N = group_size
|
||||
# Allocate the scale tensor in either row- or column-major format.
|
||||
if column_major_scales:
|
||||
shape = (x.shape[-1] // group_size, ) + x.shape[:-1]
|
||||
x_s = torch.empty(shape, device=x.device,
|
||||
@ -407,6 +407,15 @@ def per_token_group_quant_fp8(
|
||||
shape = x.shape[:-1] + (x.shape[-1] // group_size, )
|
||||
x_s = torch.empty(shape, device=x.device, dtype=torch.float32)
|
||||
|
||||
# prefer CUDA kernel if available
|
||||
if current_platform.is_cuda() and x.is_contiguous():
|
||||
torch.ops._C.per_token_group_fp8_quant(x, x_q, x_s, group_size, eps,
|
||||
fp8_min, fp8_max, use_ue8m0)
|
||||
return x_q, x_s
|
||||
|
||||
# TRITON FALLBACK
|
||||
M = x.numel() // group_size
|
||||
N = group_size
|
||||
BLOCK = triton.next_power_of_2(N)
|
||||
# heuristics for number of warps
|
||||
num_warps = min(max(BLOCK // 256, 1), 8)
|
||||
@ -423,7 +432,7 @@ def per_token_group_quant_fp8(
|
||||
eps,
|
||||
fp8_min=fp8_min,
|
||||
fp8_max=fp8_max,
|
||||
use_ue8m0=is_blackwell_deep_gemm_used(),
|
||||
use_ue8m0=use_ue8m0,
|
||||
BLOCK=BLOCK,
|
||||
num_warps=num_warps,
|
||||
num_stages=num_stages,
|
||||
@ -439,7 +448,7 @@ def per_token_group_quant_fp8(
|
||||
eps,
|
||||
fp8_min=fp8_min,
|
||||
fp8_max=fp8_max,
|
||||
use_ue8m0=is_blackwell_deep_gemm_used(),
|
||||
use_ue8m0=use_ue8m0,
|
||||
BLOCK=BLOCK,
|
||||
num_warps=num_warps,
|
||||
num_stages=num_stages,
|
||||
|
Reference in New Issue
Block a user