[CI] fix pre-commit error (#12494)

Signed-off-by: Mengqing Cao <cmq0113@163.com>
This commit is contained in:
Mengqing Cao
2025-01-28 14:11:05 +08:00
committed by GitHub
parent 0f465ab533
commit dd66fd2b01
2 changed files with 29 additions and 16 deletions

View File

@ -106,11 +106,12 @@ def _flash_attention_core(
assert (continuous_batching_mask
is not None), "continuous_batching_mask input is required."
if continuous_batching_mask is not None:
assert (logit_bias_tile is
None), "continuous_batching_mask does not support logit_bias!"
assert (
logit_bias_tile
is None), "continuous_batching_mask does not support logit_bias!"
# mask are used to only apply computation to the lower half of the matrix,
# which reduce the arthimetic intensity by half
# which reduce the arithmetic intensity by half
forward_mask = (q_tile_idx * B_P_SIZE >= local_k_large_tile_idx *
LARGE_TILE_SZ if use_causal_mask else None)
@ -468,9 +469,11 @@ def flash_paged_attention(
block_in_partition)
loaded_v = nl.load(value_cache[block_tables_sbuf[v_i, j], :,
head_id, :])
cur_v_tile[partition_idx,
nl.ds(block_in_partition *
block_size, block_size), :, ] = loaded_v
cur_v_tile[
partition_idx,
nl.ds(block_in_partition * block_size, block_size),
:,
] = loaded_v
cur_mask = nl.ndarray((par_dim(B_P_SIZE), LARGE_TILE_SZ),
dtype=mask.dtype)
@ -601,20 +604,30 @@ def flash_paged_attention(
)
nl.store(
o[batch_id, head_id * q_h_per_k_h + i_q_h,
nl.ds(i * B_P_SIZE, B_P_SIZE), :, ],
o[
batch_id,
head_id * q_h_per_k_h + i_q_h,
nl.ds(i * B_P_SIZE, B_P_SIZE),
:,
],
out,
)
# maximum and summation statistics
if return_debug_tensors:
nl.store(
hbm_m_buffer[batch_id, head_id * q_h_per_k_h + i_q_h,
nl.ds(i * B_P_SIZE, B_P_SIZE), ],
hbm_m_buffer[
batch_id,
head_id * q_h_per_k_h + i_q_h,
nl.ds(i * B_P_SIZE, B_P_SIZE),
],
m_buffer[i, i_q_h, :, :],
)
nl.store(
hbm_l_buffer[batch_id, head_id * q_h_per_k_h + i_q_h,
nl.ds(i * B_P_SIZE, B_P_SIZE), ],
hbm_l_buffer[
batch_id,
head_id * q_h_per_k_h + i_q_h,
nl.ds(i * B_P_SIZE, B_P_SIZE),
],
l_buffer[:, i, i_q_h],
)
nl.store(

View File

@ -870,10 +870,10 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
accepted_index = accepted_token_ids + 1 # Convert -1 to 0
accepted_index = accepted_index.count_nonzero(dim=1).add_(-1) # b
# Drop non-terminal prefill chunks hidden states.
hidden_states = hidden_states[
accepted_index != VLLM_INVALID_TOKEN_ID]
accepted_index = accepted_index[
accepted_index != VLLM_INVALID_TOKEN_ID]
hidden_states = hidden_states[accepted_index !=
VLLM_INVALID_TOKEN_ID]
accepted_index = accepted_index[accepted_index !=
VLLM_INVALID_TOKEN_ID]
assert len(accepted_index) == hidden_states.shape[0] == len(
terminal_metadata)
index = accepted_index[:, None, None].expand(-1, 1,