[AMD][Kernel][Bugfix] Cast offsets tensor bn to tl.int64 to avoid GPU segfault (#23692)

Signed-off-by: Randall Smith <Randall.Smith@amd.com>
This commit is contained in:
rasmith
2025-09-02 17:13:57 -05:00
committed by GitHub
parent d328f7894f
commit 457e471971

View File

@@ -146,7 +146,7 @@ def _fwd_kernel(Q,
start_n = tl.multiple_of(start_n, BLOCK_SIZE)
# -- compute qk ----
bn = tl.load(B_Loc + cur_batch * stride_b_loc_b +
-(start_n // BLOCK_SIZE) * stride_b_loc_s)
+(start_n // BLOCK_SIZE) * stride_b_loc_s).to(tl.int64)
# [D,BLOCK_SIZE]
off_k = (
bn[None, :] * stride_k_cache_bs + cur_kv_head * stride_k_cache_h +
@@ -367,7 +367,7 @@ def _fwd_kernel_flash_attn_v2(
bn = tl.load(B_Loc + cur_batch * stride_b_loc_b +
((start_n + offs_n) // block_size) * stride_b_loc_s,
mask=(start_n + offs_n) < cur_batch_ctx_len,
-other=0)
+other=0).to(tl.int64)
off_k = (
bn[None, :] * stride_k_cache_bs + cur_kv_head * stride_k_cache_h +
(offs_d[:, None] // x) * stride_k_cache_d +
@@ -575,7 +575,7 @@ def _fwd_kernel_alibi(
bn = tl.load(B_Loc + cur_batch * stride_b_loc_b +
((start_n + offs_n) // block_size) * stride_b_loc_s,
mask=(start_n + offs_n) < cur_batch_ctx_len,
-other=0)
+other=0).to(tl.int64)
off_k = (
bn[None, :] * stride_k_cache_bs + cur_kv_head * stride_k_cache_h +
(offs_d[:, None] // x) * stride_k_cache_d +