[Bugfix] fixes the causal_conv1d_update kernel update non-speculative decoding cases (#24680)

Signed-off-by: Tao He <linzhu.ht@alibaba-inc.com>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
This commit is contained in:
Tao He
2025-09-12 09:16:43 +08:00
committed by GitHub
parent 40b6c9122b
commit 880c741bb6

View File

@ -720,15 +720,15 @@ def _causal_conv1d_update_kernel(
# STEP 2: assume state_len > seqlen
idx_tokens = tl.arange(0, NP2_STATELEN) # [BLOCK_M]
# The conv_state updates works in a sliding window manner,
# at each forward pass, the tokens are shift by 1, so we
# With speculative decoding, the conv_state updates works in a sliding
# window manner, at each forward pass, the tokens are shift by 1, so we
# load since idx_tokens + 1.
conv_state_ptrs_source = (
conv_state_ptr + (conv_state_batch_coord * stride_conv_state_seq) +
conv_state_token_offset * stride_conv_state_tok +
(idx_feats * stride_conv_state_dim)[None, :] +
((idx_tokens + 1) * stride_conv_state_tok)[:, None]
) # [BLOCK_M, BLOCK_N]
((idx_tokens + (1 if IS_SPEC_DECODING else seqlen)) *
stride_conv_state_tok)[:, None]) # [BLOCK_M, BLOCK_N]
mask = ((conv_state_batch_coord < num_cache_lines)
& ((idx_tokens + seqlen) < state_len)[:, None]
& (idx_feats < dim)[None, :])
@ -924,7 +924,10 @@ def causal_conv1d_update(
)
stride_state_indices = conv_state_indices.stride(
0) if conv_state_indices is not None else 0
state_len = width - 1 + (seqlen - 1) # effective state_len needed
if num_accepted_tokens is not None:
state_len = width - 1 + (seqlen - 1) # effective state_len needed
else:
state_len = width - 1
np2_statelen = triton.next_power_of_2(state_len)
def grid(META):