vllm-ascend/vllm_ascend/distributed/kv_transfer/simple_connector.py
Li Wang c7446438a9 [1/N][CI] Move linting system to pre-commits hooks (#1256)
### What this PR does / why we need it?

Follow the vllm-project/vllm linting approach:
https://github.com/vllm-project/vllm/blob/main/.pre-commit-config.yaml

Enable pre-commit to catch low-level errors as much as possible.

This PR is one step of #1241. The purpose is to make the linting system
clearer and more convenient. This step mainly enables the following hooks:
yapf, actionlint, ruff, typos, isort, mypy, png-lint, signoff-commit,
enforce-import-regex-instead-of-re.
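
As a concrete example, the `enforce-import-regex-instead-of-re` hook asks contributors to import the third-party `regex` package (as `re`) rather than the stdlib `re` module. Below is a minimal sketch of the pattern it expects, assuming the `regex` package is installed; the exact rule lives in the pre-commit config:

```
# Flagged by the hook (stdlib re):
#   import re

# Expected pattern: the regex package is a drop-in replacement for re.
import regex as re

pattern = re.compile(r"v(\d+)\.(\d+)\.(\d+)")
print(pattern.match("v0.9.1").groups())  # ('0', '9', '1')
```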

TODO:
- clang-format (checks csrc against the Google style): the code needs cleanup first, disabled for now
- pymarkdown: the code needs cleanup first, disabled for now
- shellcheck: the code needs cleanup first, disabled for now

### Does this PR introduce _any_ user-facing change?

Only a developer UX change:

https://vllm-ascend--1256.org.readthedocs.build/en/1256/developer_guide/contributing.html#run-lint-locally

```
pip install -r requirements-lint.txt && pre-commit install
bash format.sh
```

### How was this patch tested?

CI passed with newly added and existing tests.

Co-authored-by: Yikun <yikunkero@gmail.com>
Co-authored-by: wangli <wangli858794774@gmail.com>
- vLLM version: v0.9.1
- vLLM main: 5358cce5ff

---------

Signed-off-by: wangli <wangli858794774@gmail.com>
2025-07-10 14:17:15 +08:00

#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import TYPE_CHECKING, List, Optional, Tuple, Union

import torch
import torch_npu
import vllm.envs as vllm_envs
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase
from vllm.distributed.parallel_state import get_dp_group
from vllm.logger import logger
from vllm.sequence import IntermediateTensors

from vllm_ascend.distributed.kv_transfer.simple_buffer import SimpleBuffer
from vllm_ascend.distributed.kv_transfer.simple_pipe import SimplePipe

if TYPE_CHECKING:
from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata


class SimpleConnector(KVConnectorBase):
def __init__(
self,
rank: int,
local_rank: int,
config: VllmConfig,
):
self.config = config
self.model_config = config.model_config.hf_config
self.tp_size = config.parallel_config.tensor_parallel_size
self.rank = rank
self.local_rank = local_rank
self.is_deepseek_mla = config.model_config.is_deepseek_mla
self.use_mla_opt = not vllm_envs.VLLM_MLA_DISABLE
self.n_layer = self.config.model_config.get_num_layers(
self.config.parallel_config)
        self.producer_data_pipe: Optional[SimplePipe] = None
        self.consumer_data_pipe: Optional[SimplePipe] = None
        self.producer_buffer: Optional[SimpleBuffer] = None
        self.consumer_buffer: Optional[SimpleBuffer] = None
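        # Exactly one pipe/buffer pair is created below, depending on whether
        # this rank acts as a KV producer (prefill side) or a KV consumer
        # (decode side).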
if self.config.kv_transfer_config.is_kv_producer:
self.producer_data_pipe = SimplePipe(
rank=rank,
local_rank=local_rank,
kv_transfer_config=config.kv_transfer_config,
hostname="",
port_offset=rank,
)
self.producer_buffer = SimpleBuffer(self.producer_data_pipe)
else:
self.consumer_data_pipe = SimplePipe(
rank=rank,
local_rank=local_rank,
kv_transfer_config=config.kv_transfer_config,
hostname="",
port_offset=rank,
)
self.consumer_buffer = SimpleBuffer(self.consumer_data_pipe)

    def select(
self,
input_tokens: Optional[torch.Tensor],
roi: Optional[torch.Tensor],
req_id: str,
) -> List[Optional[torch.Tensor]]:
assert self.consumer_buffer is not None, (
"Please initialize the "
"consumer buffer before calling select.")
return self.consumer_buffer.drop_select(input_tokens, roi, req_id)

    def insert(
self,
input_tokens: torch.Tensor,
roi: torch.Tensor,
keys: torch.Tensor,
values: torch.Tensor,
hidden: torch.Tensor,
req_id: str,
) -> None:
assert self.producer_buffer is not None, (
"Please initialize the "
"producer buffer before calling insert.")
self.producer_buffer.insert(input_tokens, roi, keys, values, hidden,
req_id)

    def send_kv_caches_and_hidden_states(
self,
model_executable: torch.nn.Module,
model_input: "ModelInputForGPUWithSamplingMetadata",
kv_caches: List[torch.Tensor],
hidden_or_intermediate_states: Union[torch.Tensor,
IntermediateTensors],
) -> None:
input_tokens_tensor = model_input.input_tokens
seq_lens = model_input.attn_metadata.seq_lens
slot_mapping_flat = model_input.attn_metadata.slot_mapping.flatten()
num_prefill_tokens = model_input.attn_metadata.num_prefill_tokens
start_layer = model_executable.model.start_layer
end_layer = model_executable.model.end_layer
model_config = self.model_config
num_heads = int(model_config.num_key_value_heads / self.tp_size)
hidden_size = model_config.hidden_size
num_attention_heads = model_config.num_attention_heads
# Deepseek's MLA (Multi-head Latent Attention) uses two different
# kv_cache shapes based on whether VLLM_MLA_DISABLE is set to 0.
# When VLLM_MLA_DISABLE=0 (default), forward absorb is applied,
# resulting in a kv_cache shape of [num_blks, blk_size, 1,
# kv_lora_rank + qk_rope_head_dim].
# When VLLM_MLA_DISABLE=1, standard FA is used instead, leading
# to a kv_cache shape of [2, num_blks, blk_size,
# num_key_value_heads / tp, qk_nope_head_dim + qk_rope_head_dim].
# For more details, see vllm/attention/backends/mla/common.py.
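        # Example: with typical DeepSeek MLA dimensions (kv_lora_rank=512,
        # qk_rope_head_dim=64, qk_nope_head_dim=128), the optimized path
        # gives head_size = 576 with num_heads = 1, while the standard path
        # gives head_size = 192 per KV head.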
if self.is_deepseek_mla and self.use_mla_opt:
head_size = (model_config.kv_lora_rank +
model_config.qk_rope_head_dim)
num_heads = 1
elif self.is_deepseek_mla and not self.use_mla_opt:
head_size = (model_config.qk_nope_head_dim +
model_config.qk_rope_head_dim)
else:
head_size = getattr(
model_config,
"head_dim",
int(hidden_size // num_attention_heads),
)
# Enumerate over all requests and insert them one by one.
for idx, slen in enumerate(seq_lens):
start_pos = sum(seq_lens[:idx])
end_pos = start_pos + slen
if start_pos >= num_prefill_tokens:
# vllm/worker/model_runner.py::_prepare_model_input_tensors:
# - input_tokens[:num_prefill_tokens] contains prefill tokens.
# - input_tokens[num_prefill_tokens:] contains decode tokens.
logger.warning("You have some decode requests while using "
"SimpleConnector. Their KVCache won't be sent.")
break
current_tokens = input_tokens_tensor[start_pos:end_pos]
keys, values = [], []
for layer_id in range(start_layer, end_layer):
kv_cache = kv_caches[layer_id - start_layer]
if self.is_deepseek_mla and self.use_mla_opt:
key_cache = kv_cache.reshape(-1, num_heads, head_size)
value_cache = kv_cache.reshape(-1, num_heads, head_size)
else:
key_cache = kv_cache[0].reshape(-1, num_heads, head_size)
value_cache = kv_cache[1].reshape(-1, num_heads, head_size)
current_slot_mapping = slot_mapping_flat[start_pos:end_pos]
keys.append(key_cache[current_slot_mapping].unsqueeze(0))
values.append(value_cache[current_slot_mapping].unsqueeze(0))
# shape: [num_layers, num_tokens, num_heads, head_size]
keys = torch.cat(keys, dim=0)
values = torch.cat(values, dim=0)
cur_req_id = list(model_input.request_ids_to_seq_ids.keys())[idx]
            # The roi case is not handled yet, so pass None here.
self.insert(
current_tokens,
None,
keys,
values,
hidden_or_intermediate_states[start_pos:end_pos],
cur_req_id,
)
logger.info("[rank%d][P]: KV send DONE.", torch.distributed.get_rank())

    def recv_kv_caches_and_hidden_states(
self,
model_executable: torch.nn.Module,
model_input: "ModelInputForGPUWithSamplingMetadata",
kv_caches: List[torch.Tensor],
) -> Tuple[
Union[torch.Tensor, IntermediateTensors],
bool,
"ModelInputForGPUWithSamplingMetadata",
]:
bypass_model_exec = True
model_config = self.model_config
# get model config
start_layer = model_executable.model.start_layer
end_layer = model_executable.model.end_layer
num_heads, head_dim = kv_caches[0].shape[-2:]
hidden_size = model_config.hidden_size
num_attention_heads = model_config.num_attention_heads
num_layers = end_layer - start_layer
if self.is_deepseek_mla and self.use_mla_opt:
head_size = (model_config.kv_lora_rank +
model_config.qk_rope_head_dim)
num_heads = 1
elif self.is_deepseek_mla and not self.use_mla_opt:
head_size = (model_config.qk_nope_head_dim +
model_config.qk_rope_head_dim)
else:
head_size = getattr(
model_config,
"head_dim",
int(hidden_size // num_attention_heads),
)
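        # Propagate the cache geometry to the consumer buffer so incoming KV
        # tensors and hidden states can be deserialized with the correct
        # shape and dtype.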
self.consumer_buffer.num_heads = num_heads # type: ignore
self.consumer_buffer.num_layers = num_layers # type: ignore
self.consumer_buffer.head_size = head_size # type: ignore
self.consumer_buffer.dtype = kv_caches[0].dtype # type: ignore
self.consumer_buffer.hidden_size = hidden_size # type: ignore
input_tokens_tensor = model_input.input_tokens
seq_lens = model_input.attn_metadata.seq_lens
num_prefill_tokens = model_input.attn_metadata.num_prefill_tokens
slot_mapping = model_input.attn_metadata.slot_mapping.flatten()
        total_tokens = (num_prefill_tokens +
                        model_input.attn_metadata.num_decode_tokens)
hidden_or_intermediate_states_for_one_req = []
input_tokens_list = []
num_computed_tokens_list = []
start_pos_list = []
# enumerate different requests
for idx, slen in enumerate(seq_lens):
start_pos = sum(seq_lens[:idx])
end_pos = start_pos + slen
if start_pos >= num_prefill_tokens:
logger.warning("You should set --enable_chunked_prefill=False "
"and --max_num_batched_tokens "
"should be equal to --max_seq_len_to_capture")
bypass_model_exec = False
assert start_pos == num_prefill_tokens
break
current_tokens = input_tokens_tensor[start_pos:end_pos]
num_tokens = slen
# collecting data for rebuilding the input
input_tokens_list.append(current_tokens)
start_pos_list.append(start_pos)
cur_req_id = list(model_input.request_ids_to_seq_ids.keys())[idx]
ret = self.select(
current_tokens,
torch.ones_like(current_tokens, dtype=bool),
cur_req_id,
)
if ret[0] is None:
# didn't find any match.
bypass_model_exec = False
num_computed_tokens_list.append(0)
continue
keys: torch.Tensor = ret[0]
values: torch.Tensor = ret[1]
hidden: torch.Tensor = ret[2]
num_computed_tokens = keys.shape[1]
num_computed_tokens_list.append(num_computed_tokens)
# check if both KV cache and the hidden states are received
# If not, need to redo the forwarding to compute missing states
            if num_computed_tokens != num_tokens or hidden is None:
bypass_model_exec = False
# update the end position based on how many tokens are cached.
end_pos = start_pos + num_computed_tokens
# put received KV caches into paged memory
for i in range(
model_executable.model.start_layer,
model_executable.model.end_layer,
):
kv_cache = kv_caches[i - model_executable.model.start_layer]
layer = model_executable.model.layers[i]
if self.is_deepseek_mla and self.use_mla_opt:
layer.self_attn.attn = layer.self_attn.mla_attn
key_cache = kv_cache
slots = slot_mapping[start_pos:end_pos]
sliced_key = keys[i - model_executable.model.start_layer]
torch_npu._npu_reshape_and_cache_siso(key=sliced_key,
key_cache=key_cache,
slot_indices=slots)
else:
key_cache, value_cache = kv_cache[0], kv_cache[1]
sliced_key = keys[i - model_executable.model.start_layer]
sliced_value = values[i -
model_executable.model.start_layer]
torch_npu._npu_reshape_and_cache(
key=sliced_key,
value=sliced_value,
key_cache=key_cache,
value_cache=value_cache,
slot_indices=slot_mapping[start_pos:end_pos],
)
hidden_or_intermediate_states_for_one_req.append(hidden)
if not bypass_model_exec:
# Some of the KV cache is not retrieved
# Here we will fall back to normal model forwarding
# But optionally you can adjust model_input so that you only do
# prefilling on those tokens that are missing KV caches.
if get_dp_group().world_size > 1:
bypass_model_exec = True
hidden_or_intermediate_states = torch.empty(
[total_tokens, hidden_size],
dtype=kv_caches[0].dtype,
device=kv_caches[0].device)
                logger.warning(
                    "Detected more than one DP rank on this decode node; "
                    "in this scenario, no recompute is expected when the "
                    "KV cache is not received.")
else:
logger.warning(
"[rank%d]: Failed to receive all KVs and hidden "
"states, redo model forwarding.",
torch.distributed.get_rank())
hidden_or_intermediate_states = None
else:
logger.debug(
"[rank%d]: Successfully received all KVs and hidden "
"states, skip model forwarding.",
torch.distributed.get_rank(),
)
            # Can't directly concat here; that might cause an error when bs = 1.
# hidden_or_intermediate_states = torch.empty(total_num_tokens, hidden_size, dtype=kv_caches[0].dtype, device=kv_caches[0].device)
if len(hidden_or_intermediate_states_for_one_req) == 1:
hidden = hidden_or_intermediate_states_for_one_req[0]
tmp_indice = torch.tensor([0] * hidden.shape[0],
dtype=torch.int64).npu()
hidden_or_intermediate_states = torch.empty_like(hidden)
torch_npu.scatter_update_(
hidden_or_intermediate_states,
tmp_indice,
hidden,
axis=-1,
)
else:
hidden_or_intermediate_states = torch.cat(
hidden_or_intermediate_states_for_one_req, dim=0)
return hidden_or_intermediate_states, bypass_model_exec, model_input

    def close(self):
        # Only one of the producer/consumer pipe+buffer pairs is created in
        # __init__; close whichever objects this rank actually owns.
        for obj in (self.producer_data_pipe, self.consumer_data_pipe,
                    self.producer_buffer, self.consumer_buffer):
            if obj is not None:
                obj.close()