mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 14:53:52 +08:00
output type conversion fix (#27159)
This commit is contained in:
@@ -134,10 +134,7 @@ def matmul_kernel_persistent(
|
||||
bias_ptrs = bias_ptr + offs_cn
|
||||
bias = tl.load(bias_ptrs, mask=offs_cn < N, other=0.0).to(tl.float32)
|
||||
accumulator += bias
|
||||
- if c_ptr.dtype.element_ty == tl.float8e4nv:
|
||||
- c = accumulator.to(tl.float8e4nv)
|
||||
- else:
|
||||
- c = accumulator.to(tl.float16)
|
||||
+ c = accumulator.to(c_ptr.dtype.element_ty)
|
||||
tl.store(c_ptrs, c, mask=c_mask)
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user