Mirror of https://github.com/vllm-project/vllm-ascend.git
### What this PR does / why we need it?
Follow the vllm-project/vllm linting setup:
https://github.com/vllm-project/vllm/blob/main/.pre-commit-config.yaml
Enable pre-commit to catch low-level errors as much as possible.
This PR is one step of #1241; the purpose is to make the linting system clearer and more convenient. This step mainly enables the following hooks:
yapf, actionlint, ruff, typos, isort, mypy, png-lint, signoff-commit,
enforce-import-regex-instead-of-re.
TODO:
- clang-format (checks csrc against the Google style): needs code cleanup, disabled for now
- pymarkdown: needs code cleanup, disabled for now
- shellcheck: needs code cleanup, disabled for now
### Does this PR introduce _any_ user-facing change?
Only developer UX change:
https://vllm-ascend--1256.org.readthedocs.build/en/1256/developer_guide/contributing.html#run-lint-locally
```
pip install -r requirements-lint.txt && pre-commit install
bash format.sh
```
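For reference, once the hooks are installed, pre-commit can also run them on demand without creating a commit. This is a minimal usage sketch; the hook ids are assumed to match the names listed above:
```
# Run every configured hook against the whole repo
pre-commit run --all-files

# Run a single hook (e.g. typos or mypy) against the whole repo
pre-commit run typos --all-files
pre-commit run mypy --all-files
```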
### How was this patch tested?
CI passed with newly added and existing tests.
Co-authored-by: Yikun <yikunkero@gmail.com>
Co-authored-by: wangli <wangli858794774@gmail.com>
- vLLM version: v0.9.1
- vLLM main:
5358cce5ff
---------
Signed-off-by: wangli <wangli858794774@gmail.com>
380 lines · 16 KiB · Python
```
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from typing import TYPE_CHECKING, List, Optional, Tuple, Union

import torch
import torch_npu
import vllm.envs as vllm_envs
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase
from vllm.distributed.parallel_state import get_dp_group
from vllm.logger import logger
from vllm.sequence import IntermediateTensors

from vllm_ascend.distributed.kv_transfer.simple_buffer import SimpleBuffer
from vllm_ascend.distributed.kv_transfer.simple_pipe import SimplePipe

if TYPE_CHECKING:
    from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata


class SimpleConnector(KVConnectorBase):

    def __init__(
        self,
        rank: int,
        local_rank: int,
        config: VllmConfig,
    ):
        self.config = config
        self.model_config = config.model_config.hf_config
        self.tp_size = config.parallel_config.tensor_parallel_size
        self.rank = rank
        self.local_rank = local_rank
        self.is_deepseek_mla = config.model_config.is_deepseek_mla
        self.use_mla_opt = not vllm_envs.VLLM_MLA_DISABLE
        self.n_layer = self.config.model_config.get_num_layers(
            self.config.parallel_config)

        self.producer_data_pipe: Optional[SimplePipe]
        self.consumer_data_pipe: Optional[SimplePipe]

        self.producer_buffer: Optional[SimpleBuffer]
        self.consumer_buffer: Optional[SimpleBuffer]

        if self.config.kv_transfer_config.is_kv_producer:
            self.producer_data_pipe = SimplePipe(
                rank=rank,
                local_rank=local_rank,
                kv_transfer_config=config.kv_transfer_config,
                hostname="",
                port_offset=rank,
            )
            self.producer_buffer = SimpleBuffer(self.producer_data_pipe)
        else:
            self.consumer_data_pipe = SimplePipe(
                rank=rank,
                local_rank=local_rank,
                kv_transfer_config=config.kv_transfer_config,
                hostname="",
                port_offset=rank,
            )
            self.consumer_buffer = SimpleBuffer(self.consumer_data_pipe)

    def select(
        self,
        input_tokens: Optional[torch.Tensor],
        roi: Optional[torch.Tensor],
        req_id: str,
    ) -> List[Optional[torch.Tensor]]:

        assert self.consumer_buffer is not None, (
            "Please initialize the "
            "consumer buffer before calling select.")
        return self.consumer_buffer.drop_select(input_tokens, roi, req_id)

    def insert(
        self,
        input_tokens: torch.Tensor,
        roi: torch.Tensor,
        keys: torch.Tensor,
        values: torch.Tensor,
        hidden: torch.Tensor,
        req_id: str,
    ) -> None:

        assert self.producer_buffer is not None, (
            "Please initialize the "
            "producer buffer before calling insert.")
        self.producer_buffer.insert(input_tokens, roi, keys, values, hidden,
                                    req_id)

    def send_kv_caches_and_hidden_states(
        self,
        model_executable: torch.nn.Module,
        model_input: "ModelInputForGPUWithSamplingMetadata",
        kv_caches: List[torch.Tensor],
        hidden_or_intermediate_states: Union[torch.Tensor,
                                             IntermediateTensors],
    ) -> None:
        input_tokens_tensor = model_input.input_tokens
        seq_lens = model_input.attn_metadata.seq_lens
        slot_mapping_flat = model_input.attn_metadata.slot_mapping.flatten()
        num_prefill_tokens = model_input.attn_metadata.num_prefill_tokens
        start_layer = model_executable.model.start_layer
        end_layer = model_executable.model.end_layer

        model_config = self.model_config
        num_heads = int(model_config.num_key_value_heads / self.tp_size)
        hidden_size = model_config.hidden_size
        num_attention_heads = model_config.num_attention_heads

        # Deepseek's MLA (Multi-head Latent Attention) uses two different
        # kv_cache shapes based on whether VLLM_MLA_DISABLE is set to 0.
        # When VLLM_MLA_DISABLE=0 (default), forward absorb is applied,
        # resulting in a kv_cache shape of [num_blks, blk_size, 1,
        # kv_lora_rank + qk_rope_head_dim].
        # When VLLM_MLA_DISABLE=1, standard FA is used instead, leading
        # to a kv_cache shape of [2, num_blks, blk_size,
        # num_key_value_heads / tp, qk_nope_head_dim + qk_rope_head_dim].
        # For more details, see vllm/attention/backends/mla/common.py.
        if self.is_deepseek_mla and self.use_mla_opt:
            head_size = (model_config.kv_lora_rank +
                         model_config.qk_rope_head_dim)
            num_heads = 1
        elif self.is_deepseek_mla and not self.use_mla_opt:
            head_size = (model_config.qk_nope_head_dim +
                         model_config.qk_rope_head_dim)
        else:
            head_size = getattr(
                model_config,
                "head_dim",
                int(hidden_size // num_attention_heads),
            )
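        # Worked example (hypothetical DeepSeek-style numbers, not taken from
        # this config): with kv_lora_rank=512 and qk_rope_head_dim=64, the MLA
        # absorb branch above yields head_size=576 with num_heads=1; with
        # qk_nope_head_dim=128, the non-absorb branch yields head_size=192.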
        # Enumerate over all requests and insert them one by one.
        for idx, slen in enumerate(seq_lens):
            start_pos = sum(seq_lens[:idx])
            end_pos = start_pos + slen

            if start_pos >= num_prefill_tokens:
                # vllm/worker/model_runner.py::_prepare_model_input_tensors:
                # - input_tokens[:num_prefill_tokens] contains prefill tokens.
                # - input_tokens[num_prefill_tokens:] contains decode tokens.
                logger.warning("You have some decode requests while using "
                               "SimpleConnector. Their KVCache won't be sent.")
                break

            current_tokens = input_tokens_tensor[start_pos:end_pos]

            keys, values = [], []

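            # Note: each per-layer cache is viewed as
            # [num_blocks * block_size, num_heads, head_size]; indexing it
            # with this request's slot mapping gathers exactly the slots
            # written for tokens [start_pos, end_pos).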
            for layer_id in range(start_layer, end_layer):
                kv_cache = kv_caches[layer_id - start_layer]

                if self.is_deepseek_mla and self.use_mla_opt:
                    key_cache = kv_cache.reshape(-1, num_heads, head_size)
                    value_cache = kv_cache.reshape(-1, num_heads, head_size)
                else:
                    key_cache = kv_cache[0].reshape(-1, num_heads, head_size)
                    value_cache = kv_cache[1].reshape(-1, num_heads, head_size)

                current_slot_mapping = slot_mapping_flat[start_pos:end_pos]

                keys.append(key_cache[current_slot_mapping].unsqueeze(0))
                values.append(value_cache[current_slot_mapping].unsqueeze(0))

            # shape: [num_layers, num_tokens, num_heads, head_size]
            keys = torch.cat(keys, dim=0)
            values = torch.cat(values, dim=0)
            cur_req_id = list(model_input.request_ids_to_seq_ids.keys())[idx]
            # The roi case is not considered yet, so pass None here.
            self.insert(
                current_tokens,
                None,
                keys,
                values,
                hidden_or_intermediate_states[start_pos:end_pos],
                cur_req_id,
            )

        logger.info("[rank%d][P]: KV send DONE.", torch.distributed.get_rank())

    def recv_kv_caches_and_hidden_states(
        self,
        model_executable: torch.nn.Module,
        model_input: "ModelInputForGPUWithSamplingMetadata",
        kv_caches: List[torch.Tensor],
    ) -> Tuple[
            Union[torch.Tensor, IntermediateTensors],
            bool,
            "ModelInputForGPUWithSamplingMetadata",
    ]:
        bypass_model_exec = True

        model_config = self.model_config

        # get model config
        start_layer = model_executable.model.start_layer
        end_layer = model_executable.model.end_layer
        num_heads, head_dim = kv_caches[0].shape[-2:]
        hidden_size = model_config.hidden_size
        num_attention_heads = model_config.num_attention_heads
        num_layers = end_layer - start_layer
        if self.is_deepseek_mla and self.use_mla_opt:
            head_size = (model_config.kv_lora_rank +
                         model_config.qk_rope_head_dim)
            num_heads = 1
        elif self.is_deepseek_mla and not self.use_mla_opt:
            head_size = (model_config.qk_nope_head_dim +
                         model_config.qk_rope_head_dim)
        else:
            head_size = getattr(
                model_config,
                "head_dim",
                int(hidden_size // num_attention_heads),
            )
        self.consumer_buffer.num_heads = num_heads  # type: ignore
        self.consumer_buffer.num_layers = num_layers  # type: ignore
        self.consumer_buffer.head_size = head_size  # type: ignore
        self.consumer_buffer.dtype = kv_caches[0].dtype  # type: ignore
        self.consumer_buffer.hidden_size = hidden_size  # type: ignore

        input_tokens_tensor = model_input.input_tokens
        seq_lens = model_input.attn_metadata.seq_lens
        num_prefill_tokens = model_input.attn_metadata.num_prefill_tokens
        slot_mapping = model_input.attn_metadata.slot_mapping.flatten()

        total_tokens = (model_input.attn_metadata.num_prefill_tokens +
                        model_input.attn_metadata.num_decode_tokens)
        hidden_or_intermediate_states_for_one_req = []

        input_tokens_list = []
        num_computed_tokens_list = []
        start_pos_list = []

        # enumerate different requests
        for idx, slen in enumerate(seq_lens):
            start_pos = sum(seq_lens[:idx])
            end_pos = start_pos + slen

            if start_pos >= num_prefill_tokens:
                logger.warning(
                    "You should set --enable_chunked_prefill=False and make "
                    "--max_num_batched_tokens equal to "
                    "--max_seq_len_to_capture.")
                bypass_model_exec = False
                assert start_pos == num_prefill_tokens
                break

            current_tokens = input_tokens_tensor[start_pos:end_pos]
            num_tokens = slen

            # collecting data for rebuilding the input
            input_tokens_list.append(current_tokens)
            start_pos_list.append(start_pos)

            cur_req_id = list(model_input.request_ids_to_seq_ids.keys())[idx]

            ret = self.select(
                current_tokens,
                torch.ones_like(current_tokens, dtype=bool),
                cur_req_id,
            )
            if ret[0] is None:
                # didn't find any match.
                bypass_model_exec = False
                num_computed_tokens_list.append(0)
                continue

            keys: torch.Tensor = ret[0]
            values: torch.Tensor = ret[1]
            hidden: torch.Tensor = ret[2]

            num_computed_tokens = keys.shape[1]
            num_computed_tokens_list.append(num_computed_tokens)

            # check if both KV cache and the hidden states are received
            # If not, need to redo the forwarding to compute missing states
            if num_computed_tokens != num_tokens or hidden is None:
                bypass_model_exec = False

            # update the end position based on how many tokens are cached.
            end_pos = start_pos + num_computed_tokens

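            # Only the cached prefix [start_pos, end_pos) is written back into
            # the paged KV below; if the transfer was partial,
            # bypass_model_exec stays False and the prefill is recomputed.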
            # put received KV caches into paged memory
            for i in range(
                    model_executable.model.start_layer,
                    model_executable.model.end_layer,
            ):
                kv_cache = kv_caches[i - model_executable.model.start_layer]
                layer = model_executable.model.layers[i]

                if self.is_deepseek_mla and self.use_mla_opt:
                    layer.self_attn.attn = layer.self_attn.mla_attn
                    key_cache = kv_cache
                    slots = slot_mapping[start_pos:end_pos]
                    sliced_key = keys[i - model_executable.model.start_layer]
                    torch_npu._npu_reshape_and_cache_siso(
                        key=sliced_key,
                        key_cache=key_cache,
                        slot_indices=slots)
                else:
                    key_cache, value_cache = kv_cache[0], kv_cache[1]
                    sliced_key = keys[i - model_executable.model.start_layer]
                    sliced_value = values[i -
                                          model_executable.model.start_layer]
                    torch_npu._npu_reshape_and_cache(
                        key=sliced_key,
                        value=sliced_value,
                        key_cache=key_cache,
                        value_cache=value_cache,
                        slot_indices=slot_mapping[start_pos:end_pos],
                    )

            hidden_or_intermediate_states_for_one_req.append(hidden)

        if not bypass_model_exec:
            # Some of the KV cache was not retrieved, so fall back to normal
            # model forwarding. Optionally, model_input could be adjusted so
            # that prefill only runs on the tokens whose KV caches are missing.
            if get_dp_group().world_size > 1:
                bypass_model_exec = True
                hidden_or_intermediate_states = torch.empty(
                    [total_tokens, hidden_size],
                    dtype=kv_caches[0].dtype,
                    device=kv_caches[0].device)
                logger.warning(
                    "Detected more than one DP rank on this decode node; in "
                    "this scenario no recompute is expected when the KV cache "
                    "is not received.")
            else:
                logger.warning(
                    "[rank%d]: Failed to receive all KVs and hidden "
                    "states, redo model forwarding.",
                    torch.distributed.get_rank())
                hidden_or_intermediate_states = None
        else:
            logger.debug(
                "[rank%d]: Successfully received all KVs and hidden "
                "states, skip model forwarding.",
                torch.distributed.get_rank(),
            )
            # We can't directly concat here, which might cause an error when
            # bs = 1.
            # hidden_or_intermediate_states = torch.empty(
            #     total_num_tokens, hidden_size,
            #     dtype=kv_caches[0].dtype, device=kv_caches[0].device)
            if len(hidden_or_intermediate_states_for_one_req) == 1:
                hidden = hidden_or_intermediate_states_for_one_req[0]
                tmp_indice = torch.tensor([0] * hidden.shape[0],
                                          dtype=torch.int64).npu()
                hidden_or_intermediate_states = torch.empty_like(hidden)
                torch_npu.scatter_update_(
                    hidden_or_intermediate_states,
                    tmp_indice,
                    hidden,
                    axis=-1,
                )
            else:
                hidden_or_intermediate_states = torch.cat(
                    hidden_or_intermediate_states_for_one_req, dim=0)

        return hidden_or_intermediate_states, bypass_model_exec, model_input

    def close(self):
        self.producer_data_pipe.close()  # type: ignore
        self.consumer_data_pipe.close()  # type: ignore
        self.producer_buffer.close()  # type: ignore
        self.consumer_buffer.close()  # type: ignore
```