mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 14:53:52 +08:00
[BugFix] [Kernel] Fix GPU SEGV occurring in int8 kernels (#9391)
This commit is contained in:
@ -96,12 +96,15 @@ __global__ void static_scaled_int8_quant_kernel(
|
||||
scalar_t const* __restrict__ input, int8_t* __restrict__ out,
|
||||
scale_type const* scale_ptr, const int hidden_size) {
|
||||
int const tid = threadIdx.x;
|
||||
int const token_idx = blockIdx.x;
|
||||
int64_t const token_idx = blockIdx.x;
|
||||
scale_type const scale = *scale_ptr;
|
||||
|
||||
// Must be performed using 64-bit math to avoid integer overflow.
|
||||
out += token_idx * hidden_size;
|
||||
input += token_idx * hidden_size;
|
||||
|
||||
for (int i = tid; i < hidden_size; i += blockDim.x) {
|
||||
out[token_idx * hidden_size + i] = float_to_int8_rn(
|
||||
static_cast<float>(input[token_idx * hidden_size + i]) / scale);
|
||||
out[i] = float_to_int8_rn(static_cast<float>(input[i]) / scale);
|
||||
}
|
||||
}
|
||||
|
||||
@ -111,14 +114,18 @@ __global__ void static_scaled_int8_azp_quant_kernel(
|
||||
scale_type const* scale_ptr, azp_type const* azp_ptr,
|
||||
const int hidden_size) {
|
||||
int const tid = threadIdx.x;
|
||||
int const token_idx = blockIdx.x;
|
||||
int64_t const token_idx = blockIdx.x;
|
||||
scale_type const scale = *scale_ptr;
|
||||
azp_type const azp = *azp_ptr;
|
||||
|
||||
// Must be performed using 64-bit math to avoid integer overflow.
|
||||
out += token_idx * hidden_size;
|
||||
input += token_idx * hidden_size;
|
||||
|
||||
for (int i = tid; i < hidden_size; i += blockDim.x) {
|
||||
auto const val = static_cast<float>(input[token_idx * hidden_size + i]);
|
||||
auto const val = static_cast<float>(input[i]);
|
||||
auto const quant_val = int32_to_int8(float_to_int32_rn(val / scale) + azp);
|
||||
out[token_idx * hidden_size + i] = quant_val;
|
||||
out[i] = quant_val;
|
||||
}
|
||||
}
|
||||
|
||||
@ -127,12 +134,16 @@ __global__ void dynamic_scaled_int8_quant_kernel(
|
||||
scalar_t const* __restrict__ input, int8_t* __restrict__ out,
|
||||
scale_type* scale, const int hidden_size) {
|
||||
int const tid = threadIdx.x;
|
||||
int const token_idx = blockIdx.x;
|
||||
int64_t const token_idx = blockIdx.x;
|
||||
float absmax_val = 0.0f;
|
||||
float const zero = 0.0f;
|
||||
|
||||
// Must be performed using 64-bit math to avoid integer overflow.
|
||||
out += token_idx * hidden_size;
|
||||
input += token_idx * hidden_size;
|
||||
|
||||
for (int i = tid; i < hidden_size; i += blockDim.x) {
|
||||
float val = static_cast<float>(input[token_idx * hidden_size + i]);
|
||||
float val = static_cast<float>(input[i]);
|
||||
val = val > zero ? val : -val;
|
||||
absmax_val = val > absmax_val ? val : absmax_val;
|
||||
}
|
||||
@ -150,8 +161,7 @@ __global__ void dynamic_scaled_int8_quant_kernel(
|
||||
|
||||
float const tmp_scale = 127.0f / block_absmax_val;
|
||||
for (int i = tid; i < hidden_size; i += blockDim.x) {
|
||||
out[token_idx * hidden_size + i] = float_to_int8_rn(
|
||||
static_cast<float>(input[token_idx * hidden_size + i]) * tmp_scale);
|
||||
out[i] = float_to_int8_rn(static_cast<float>(input[i]) * tmp_scale);
|
||||
}
|
||||
}
|
||||
|
||||
@ -159,13 +169,17 @@ template <typename scalar_t, typename scale_type, typename azp_type>
|
||||
__global__ void dynamic_scaled_int8_azp_quant_kernel(
|
||||
scalar_t const* __restrict__ input, int8_t* __restrict__ out,
|
||||
scale_type* scale, azp_type* azp, const int hidden_size) {
|
||||
int const token_idx = blockIdx.x;
|
||||
int64_t const token_idx = blockIdx.x;
|
||||
|
||||
// Must be performed using 64-bit math to avoid integer overflow.
|
||||
out += token_idx * hidden_size;
|
||||
input += token_idx * hidden_size;
|
||||
|
||||
// Scan for the min and max value for this token
|
||||
float max_val = std::numeric_limits<float>::min();
|
||||
float min_val = std::numeric_limits<float>::max();
|
||||
for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) {
|
||||
auto val = static_cast<float>(input[token_idx * hidden_size + i]);
|
||||
auto val = static_cast<float>(input[i]);
|
||||
max_val = std::max(max_val, val);
|
||||
min_val = std::min(min_val, val);
|
||||
}
|
||||
@ -200,10 +214,10 @@ __global__ void dynamic_scaled_int8_azp_quant_kernel(
|
||||
|
||||
// Quantize the values
|
||||
for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) {
|
||||
auto const val = static_cast<float>(input[token_idx * hidden_size + i]);
|
||||
auto const val = static_cast<float>(input[i]);
|
||||
auto const quant_val =
|
||||
int32_to_int8(float_to_int32_rn(val / scale_val) + azp_val);
|
||||
out[token_idx * hidden_size + i] = quant_val;
|
||||
out[i] = quant_val;
|
||||
}
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user