Use unsigned int for the bias pipelining hack in Triton FlashAttention. A bitwise AND of 0xFFFFFFFF with a signed 32-bit integer overflows, since the constant does not fit in int32; this was caught by the latest Triton integration.

PiperOrigin-RevId: 739868199
Change-Id: I732e36901c18e4a07c8907c52412b7475d2ebb64
This commit is contained in:
Aliia Khasanova
2025-03-24 03:01:18 -07:00
committed by Copybara-Service
parent 8151373a6b
commit 69b749cb71


@@ -83,7 +83,7 @@ def _fwd_kernel_inner(
     # Prevent dot accumulating into the bias tensor. It appears that Triton
     # doesn't pipeline the bias load as it does the `k` load, so the bias load
     # blocks the matmul if the add is merged.
-    qk = qk.to(tl.int32, bitcast=True) & 0xFFFFFFFF
+    qk = qk.to(tl.uint32, bitcast=True) & 0xFFFFFFFF
     qk = qk.to(tl.float32, bitcast=True)
     qk += bias
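A minimal sketch of why the signed view fails, using plain Python `struct` rather than Triton (the function names here are illustrative, not from the patch). With an unsigned 32-bit view, ANDing with 0xFFFFFFFF is the identity and the float bit pattern round-trips exactly; with a signed view, any float whose sign bit is set (a negative int32 bit pattern) produces a masked value above INT32_MAX, which no longer fits in a signed 32-bit slot:

```python
import struct

def roundtrip_unsigned(x: float) -> float:
    """Bitcast float32 -> uint32, AND with 0xFFFFFFFF, bitcast back."""
    bits = struct.unpack('<I', struct.pack('<f', x))[0]  # unsigned 32-bit view
    bits &= 0xFFFFFFFF                                   # identity on uint32
    return struct.unpack('<f', struct.pack('<I', bits))[0]

def roundtrip_signed(x: float) -> float:
    """Same trick with a signed 32-bit view ('i' instead of 'I')."""
    bits = struct.unpack('<i', struct.pack('<f', x))[0]  # signed 32-bit view
    # For negative `bits` (sign bit set), Python's arbitrary-precision AND
    # yields a value above INT32_MAX, so repacking as int32 overflows.
    bits &= 0xFFFFFFFF
    return struct.unpack('<f', struct.pack('<i', bits))[0]  # struct.error
```

`roundtrip_unsigned(-3.5)` returns -3.5 exactly, while `roundtrip_signed(-3.5)` raises `struct.error`: the bits of -3.5 are 0xC0600000, which as int32 is negative, and the masked result 0xC0600000 exceeds the int32 range. This mirrors the Triton fix, where `tl.uint32` keeps the mask an in-range identity operation.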