mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 14:53:52 +08:00
[Bugfix] Fix __syncwarp
on ROCM (#25996)
This commit is contained in:
@ -536,7 +536,9 @@ __global__ void indexer_k_quant_and_cache_kernel(
|
||||
for (int i = 0; i < VEC_SIZE; i++) {
|
||||
amax = fmaxf(amax, fabsf(float(k_val_ptr[i])));
|
||||
}
|
||||
#ifndef USE_ROCM
|
||||
__syncwarp();
|
||||
#endif
|
||||
|
||||
// Reduced amax
|
||||
for (int mask = 16; mask > 0; mask /= 2) {
|
||||
@ -546,7 +548,9 @@ __global__ void indexer_k_quant_and_cache_kernel(
|
||||
amax = fmaxf(amax, __shfl_xor_sync(unsigned(-1), amax, mask));
|
||||
#endif
|
||||
}
|
||||
#ifndef USE_ROCM
|
||||
__syncwarp();
|
||||
#endif
|
||||
float scale = fmaxf(amax, 1e-4) / 448.0f;
|
||||
if (use_ue8m0) {
|
||||
scale = exp2f(ceilf(log2f(scale)));
|
||||
@ -1167,4 +1171,4 @@ void indexer_k_quant_and_cache(
|
||||
|
||||
DISPATCH_BY_KV_CACHE_DTYPE(k.dtype(), "fp8_e4m3",
|
||||
CALL_INDEXER_K_QUANT_AND_CACHE);
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user