[BugFix] [Kernel] Fix GPU SEGV occurring in int8 kernels (#9391)

This commit is contained in:
rasmith
2024-10-16 20:34:06 -05:00
committed by GitHub
parent c3fab5f769
commit 92d86da217

View File

@ -96,12 +96,15 @@ __global__ void static_scaled_int8_quant_kernel(
scalar_t const* __restrict__ input, int8_t* __restrict__ out, scalar_t const* __restrict__ input, int8_t* __restrict__ out,
scale_type const* scale_ptr, const int hidden_size) { scale_type const* scale_ptr, const int hidden_size) {
int const tid = threadIdx.x; int const tid = threadIdx.x;
int const token_idx = blockIdx.x; int64_t const token_idx = blockIdx.x;
scale_type const scale = *scale_ptr; scale_type const scale = *scale_ptr;
// Must be performed using 64-bit math to avoid integer overflow.
out += token_idx * hidden_size;
input += token_idx * hidden_size;
for (int i = tid; i < hidden_size; i += blockDim.x) { for (int i = tid; i < hidden_size; i += blockDim.x) {
out[token_idx * hidden_size + i] = float_to_int8_rn( out[i] = float_to_int8_rn(static_cast<float>(input[i]) / scale);
static_cast<float>(input[token_idx * hidden_size + i]) / scale);
} }
} }
@ -111,14 +114,18 @@ __global__ void static_scaled_int8_azp_quant_kernel(
scale_type const* scale_ptr, azp_type const* azp_ptr, scale_type const* scale_ptr, azp_type const* azp_ptr,
const int hidden_size) { const int hidden_size) {
int const tid = threadIdx.x; int const tid = threadIdx.x;
int const token_idx = blockIdx.x; int64_t const token_idx = blockIdx.x;
scale_type const scale = *scale_ptr; scale_type const scale = *scale_ptr;
azp_type const azp = *azp_ptr; azp_type const azp = *azp_ptr;
// Must be performed using 64-bit math to avoid integer overflow.
out += token_idx * hidden_size;
input += token_idx * hidden_size;
for (int i = tid; i < hidden_size; i += blockDim.x) { for (int i = tid; i < hidden_size; i += blockDim.x) {
auto const val = static_cast<float>(input[token_idx * hidden_size + i]); auto const val = static_cast<float>(input[i]);
auto const quant_val = int32_to_int8(float_to_int32_rn(val / scale) + azp); auto const quant_val = int32_to_int8(float_to_int32_rn(val / scale) + azp);
out[token_idx * hidden_size + i] = quant_val; out[i] = quant_val;
} }
} }
@ -127,12 +134,16 @@ __global__ void dynamic_scaled_int8_quant_kernel(
scalar_t const* __restrict__ input, int8_t* __restrict__ out, scalar_t const* __restrict__ input, int8_t* __restrict__ out,
scale_type* scale, const int hidden_size) { scale_type* scale, const int hidden_size) {
int const tid = threadIdx.x; int const tid = threadIdx.x;
int const token_idx = blockIdx.x; int64_t const token_idx = blockIdx.x;
float absmax_val = 0.0f; float absmax_val = 0.0f;
float const zero = 0.0f; float const zero = 0.0f;
// Must be performed using 64-bit math to avoid integer overflow.
out += token_idx * hidden_size;
input += token_idx * hidden_size;
for (int i = tid; i < hidden_size; i += blockDim.x) { for (int i = tid; i < hidden_size; i += blockDim.x) {
float val = static_cast<float>(input[token_idx * hidden_size + i]); float val = static_cast<float>(input[i]);
val = val > zero ? val : -val; val = val > zero ? val : -val;
absmax_val = val > absmax_val ? val : absmax_val; absmax_val = val > absmax_val ? val : absmax_val;
} }
@ -150,8 +161,7 @@ __global__ void dynamic_scaled_int8_quant_kernel(
float const tmp_scale = 127.0f / block_absmax_val; float const tmp_scale = 127.0f / block_absmax_val;
for (int i = tid; i < hidden_size; i += blockDim.x) { for (int i = tid; i < hidden_size; i += blockDim.x) {
out[token_idx * hidden_size + i] = float_to_int8_rn( out[i] = float_to_int8_rn(static_cast<float>(input[i]) * tmp_scale);
static_cast<float>(input[token_idx * hidden_size + i]) * tmp_scale);
} }
} }
@ -159,13 +169,17 @@ template <typename scalar_t, typename scale_type, typename azp_type>
__global__ void dynamic_scaled_int8_azp_quant_kernel( __global__ void dynamic_scaled_int8_azp_quant_kernel(
scalar_t const* __restrict__ input, int8_t* __restrict__ out, scalar_t const* __restrict__ input, int8_t* __restrict__ out,
scale_type* scale, azp_type* azp, const int hidden_size) { scale_type* scale, azp_type* azp, const int hidden_size) {
int const token_idx = blockIdx.x; int64_t const token_idx = blockIdx.x;
// Must be performed using 64-bit math to avoid integer overflow.
out += token_idx * hidden_size;
input += token_idx * hidden_size;
// Scan for the min and max value for this token // Scan for the min and max value for this token
float max_val = std::numeric_limits<float>::min(); float max_val = std::numeric_limits<float>::min();
float min_val = std::numeric_limits<float>::max(); float min_val = std::numeric_limits<float>::max();
for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) { for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) {
auto val = static_cast<float>(input[token_idx * hidden_size + i]); auto val = static_cast<float>(input[i]);
max_val = std::max(max_val, val); max_val = std::max(max_val, val);
min_val = std::min(min_val, val); min_val = std::min(min_val, val);
} }
@ -200,10 +214,10 @@ __global__ void dynamic_scaled_int8_azp_quant_kernel(
// Quantize the values // Quantize the values
for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) { for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) {
auto const val = static_cast<float>(input[token_idx * hidden_size + i]); auto const val = static_cast<float>(input[i]);
auto const quant_val = auto const quant_val =
int32_to_int8(float_to_int32_rn(val / scale_val) + azp_val); int32_to_int8(float_to_int32_rn(val / scale_val) + azp_val);
out[token_idx * hidden_size + i] = quant_val; out[i] = quant_val;
} }
} }