output type conversion fix (#27159)

This commit is contained in:
Jianyu Huang
2025-10-19 01:10:07 -07:00
committed by GitHub
parent b3aba04e5a
commit 221bf72577

View File

@ -134,10 +134,7 @@ def matmul_kernel_persistent(
bias_ptrs = bias_ptr + offs_cn
bias = tl.load(bias_ptrs, mask=offs_cn < N, other=0.0).to(tl.float32)
accumulator += bias
if c_ptr.dtype.element_ty == tl.float8e4nv:
c = accumulator.to(tl.float8e4nv)
else:
c = accumulator.to(tl.float16)
c = accumulator.to(c_ptr.dtype.element_ty)
tl.store(c_ptrs, c, mask=c_mask)