mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 14:53:52 +08:00
[Bugfix] Fix the causal_conv1d_update kernel update in non-speculative decoding cases (#24680)
Signed-off-by: Tao He <linzhu.ht@alibaba-inc.com> Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
This commit is contained in:
@ -720,15 +720,15 @@ def _causal_conv1d_update_kernel(
|
||||
# STEP 2: assume state_len > seqlen
|
||||
idx_tokens = tl.arange(0, NP2_STATELEN) # [BLOCK_M]
|
||||
|
||||
# The conv_state update works in a sliding window manner,
|
||||
# at each forward pass, the tokens are shifted by 1, so we
|
||||
# With speculative decoding, the conv_state update works in a sliding
|
||||
# window manner: at each forward pass, the tokens are shifted by 1, so we
|
||||
# load starting from idx_tokens + 1.
|
||||
conv_state_ptrs_source = (
|
||||
conv_state_ptr + (conv_state_batch_coord * stride_conv_state_seq) +
|
||||
conv_state_token_offset * stride_conv_state_tok +
|
||||
(idx_feats * stride_conv_state_dim)[None, :] +
|
||||
((idx_tokens + 1) * stride_conv_state_tok)[:, None]
|
||||
) # [BLOCK_M, BLOCK_N]
|
||||
((idx_tokens + (1 if IS_SPEC_DECODING else seqlen)) *
|
||||
stride_conv_state_tok)[:, None]) # [BLOCK_M, BLOCK_N]
|
||||
mask = ((conv_state_batch_coord < num_cache_lines)
|
||||
& ((idx_tokens + seqlen) < state_len)[:, None]
|
||||
& (idx_feats < dim)[None, :])
|
||||
@ -924,7 +924,10 @@ def causal_conv1d_update(
|
||||
)
|
||||
stride_state_indices = conv_state_indices.stride(
|
||||
0) if conv_state_indices is not None else 0
|
||||
state_len = width - 1 + (seqlen - 1) # effective state_len needed
|
||||
if num_accepted_tokens is not None:
|
||||
state_len = width - 1 + (seqlen - 1) # effective state_len needed
|
||||
else:
|
||||
state_len = width - 1
|
||||
np2_statelen = triton.next_power_of_2(state_len)
|
||||
|
||||
def grid(META):
|
||||
|
Reference in New Issue
Block a user