mirror of https://github.com/vllm-project/vllm.git
synced 2025-10-20 14:53:52 +08:00
[CI] fix pre-commit error (#12494)
Signed-off-by: Mengqing Cao <cmq0113@163.com>
@@ -106,11 +106,12 @@ def _flash_attention_core(
     assert (continuous_batching_mask
             is not None), "continuous_batching_mask input is required."
     if continuous_batching_mask is not None:
-        assert (logit_bias_tile is
-                None), "continuous_batching_mask does not support logit_bias!"
+        assert (
+            logit_bias_tile
+            is None), "continuous_batching_mask does not support logit_bias!"
 
     # mask are used to only apply computation to the lower half of the matrix,
-    # which reduce the arthimetic intensity by half
+    # which reduce the arithmetic intensity by half
     forward_mask = (q_tile_idx * B_P_SIZE >= local_k_large_tile_idx *
                     LARGE_TILE_SZ if use_causal_mask else None)
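For context, a minimal standalone sketch of the tile-level causal check computed into forward_mask above (plain Python, not the NKI kernel; the concrete tile sizes are assumptions):

# Sketch only: reproduces the inequality from the hunk above.
# B_P_SIZE is the query-tile height, LARGE_TILE_SZ the key-tile width.
B_P_SIZE = 128        # assumed
LARGE_TILE_SZ = 2048  # assumed

def forward_mask(q_tile_idx: int, local_k_large_tile_idx: int,
                 use_causal_mask: bool):
    """True when the query tile starts at or after the key tile, i.e. the
    pair can contain lower-triangular (causal) work; None disables the
    mask when causal masking is off."""
    if not use_causal_mask:
        return None
    return q_tile_idx * B_P_SIZE >= local_k_large_tile_idx * LARGE_TILE_SZ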
@@ -468,9 +469,11 @@ def flash_paged_attention(
                                                block_in_partition)
                 loaded_v = nl.load(value_cache[block_tables_sbuf[v_i, j], :,
                                                head_id, :])
-                cur_v_tile[partition_idx,
-                           nl.ds(block_in_partition *
-                                 block_size, block_size), :, ] = loaded_v
+                cur_v_tile[
+                    partition_idx,
+                    nl.ds(block_in_partition * block_size, block_size),
+                    :,
+                ] = loaded_v
 
         cur_mask = nl.ndarray((par_dim(B_P_SIZE), LARGE_TILE_SZ),
                               dtype=mask.dtype)
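The reformatted assignment above copies one paged value-cache block into its slot of a contiguous value tile. A rough NumPy analogue, where nl.ds(start, size) is read as the slice [start:start+size] and all shapes and names are assumptions for illustration:

import numpy as np

# Assumed shapes for illustration only.
num_blocks, block_size, head_dim = 64, 16, 128
num_partitions, blocks_per_partition = 4, 8

value_cache = np.random.rand(num_blocks, block_size, head_dim)
block_table = np.random.randint(0, num_blocks,
                                size=num_partitions * blocks_per_partition)
cur_v_tile = np.empty(
    (num_partitions, blocks_per_partition * block_size, head_dim))

for partition_idx in range(num_partitions):
    for block_in_partition in range(blocks_per_partition):
        v_i = partition_idx * blocks_per_partition + block_in_partition
        loaded_v = value_cache[block_table[v_i]]  # one block of V rows
        # nl.ds(block_in_partition * block_size, block_size) ~ this slice:
        start = block_in_partition * block_size
        cur_v_tile[partition_idx, start:start + block_size, :] = loaded_v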
@@ -601,20 +604,30 @@ def flash_paged_attention(
             )
 
             nl.store(
-                o[batch_id, head_id * q_h_per_k_h + i_q_h,
-                  nl.ds(i * B_P_SIZE, B_P_SIZE), :, ],
+                o[
+                    batch_id,
+                    head_id * q_h_per_k_h + i_q_h,
+                    nl.ds(i * B_P_SIZE, B_P_SIZE),
+                    :,
+                ],
                 out,
             )
             # maximum and summation statistics
             if return_debug_tensors:
                 nl.store(
-                    hbm_m_buffer[batch_id, head_id * q_h_per_k_h + i_q_h,
-                                 nl.ds(i * B_P_SIZE, B_P_SIZE), ],
+                    hbm_m_buffer[
+                        batch_id,
+                        head_id * q_h_per_k_h + i_q_h,
+                        nl.ds(i * B_P_SIZE, B_P_SIZE),
+                    ],
                     m_buffer[i, i_q_h, :, :],
                 )
                 nl.store(
-                    hbm_l_buffer[batch_id, head_id * q_h_per_k_h + i_q_h,
-                                 nl.ds(i * B_P_SIZE, B_P_SIZE), ],
+                    hbm_l_buffer[
+                        batch_id,
+                        head_id * q_h_per_k_h + i_q_h,
+                        nl.ds(i * B_P_SIZE, B_P_SIZE),
+                    ],
                     l_buffer[:, i, i_q_h],
                 )
                 nl.store(
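Each nl.store above writes one B_P_SIZE-row block of kernel output back to HBM for a single (batch, head) pair. A NumPy analogue of the first store, with sizes and the head layout assumed purely for illustration:

import numpy as np

# Assumed sizes for illustration only.
B_P_SIZE, head_dim = 128, 64
batch, k_heads, q_h_per_k_h, n_blocks = 1, 4, 2, 4
o = np.zeros((batch, k_heads * q_h_per_k_h, n_blocks * B_P_SIZE, head_dim))

batch_id, head_id, i_q_h, i = 0, 2, 1, 3
out = np.random.rand(B_P_SIZE, head_dim)  # tile produced by the kernel

# nl.ds(i * B_P_SIZE, B_P_SIZE) is read as the slice below.
o[batch_id,
  head_id * q_h_per_k_h + i_q_h,
  i * B_P_SIZE:(i + 1) * B_P_SIZE,
  :] = out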
@@ -870,10 +870,10 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
             accepted_index = accepted_token_ids + 1  # Convert -1 to 0
             accepted_index = accepted_index.count_nonzero(dim=1).add_(-1)  # b
             # Drop non-terminal prefill chunks hidden states.
-            hidden_states = hidden_states[
-                accepted_index != VLLM_INVALID_TOKEN_ID]
-            accepted_index = accepted_index[
-                accepted_index != VLLM_INVALID_TOKEN_ID]
+            hidden_states = hidden_states[accepted_index !=
+                                          VLLM_INVALID_TOKEN_ID]
+            accepted_index = accepted_index[accepted_index !=
+                                            VLLM_INVALID_TOKEN_ID]
             assert len(accepted_index) == hidden_states.shape[0] == len(
                 terminal_metadata)
             index = accepted_index[:, None, None].expand(-1, 1,
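A small PyTorch sketch of the bookkeeping in this hunk; tensor contents, shapes, and the value VLLM_INVALID_TOKEN_ID = -1 are assumptions for illustration:

import torch

VLLM_INVALID_TOKEN_ID = -1  # assumed value for this sketch

# [batch, k] proposed tokens; -1 marks a rejected/absent token.
accepted_token_ids = torch.tensor([[11, 12, -1],
                                   [-1, -1, -1],
                                   [21, 22, 23]])
hidden_states = torch.randn(3, 3, 8)  # [batch, k, hidden]

accepted_index = accepted_token_ids + 1  # Convert -1 to 0
accepted_index = accepted_index.count_nonzero(dim=1).add_(-1)  # b

# Drop rows whose index equals the invalid id (nothing was accepted there).
hidden_states = hidden_states[accepted_index != VLLM_INVALID_TOKEN_ID]
accepted_index = accepted_index[accepted_index != VLLM_INVALID_TOKEN_ID]
assert len(accepted_index) == hidden_states.shape[0]

# Pick the hidden state of the last accepted token in each remaining row.
index = accepted_index[:, None, None].expand(-1, 1, hidden_states.shape[-1])
last_hidden = hidden_states.gather(dim=1, index=index).squeeze(1)  # [b', hidden]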