[Bugfix][Kernel] Prevent integer overflow in fp8 dynamic per-token quantize kernel (#9425)

This commit is contained in:
Tyler Michael Smith
2024-10-16 19:46:06 -04:00
committed by GitHub
parent 776dbd74f1
commit c3fab5f769

View File

@ -204,8 +204,10 @@ __global__ void dynamic_per_token_scaled_fp8_quant_kernel(
int const tid = threadIdx.x;
int const token_idx = blockIdx.x;
scalar_t const* __restrict__ token_input = &input[token_idx * hidden_size];
FP8_TYPE* __restrict__ token_output = &out[token_idx * hidden_size];
// Use int64 to avoid overflowing an int32 when calculating this offset
int64_t offset = static_cast<int64_t>(token_idx) * hidden_size;
scalar_t const* __restrict__ token_input = &input[offset];
FP8_TYPE* __restrict__ token_output = &out[offset];
// For vectorization, token_input and token_output pointers need to be
// aligned at 8-byte and 4-byte addresses respectively.