Mirror of https://github.com/vllm-project/vllm-ascend.git
[4/N][Refactor] torchair model runner refactor (#2208)
There is a lot of torchair code in the model runner, which makes the code hard to maintain. We'll create a new torchair_model_runner to split out the torchair-related logic. Following the workflow in #2203, this is the first PR.
What this PR does:
- Create a common function `_convert_torch_format` for `initialize_kv_cache`.
- vLLM version: v0.10.0
- vLLM main:
14a5d903ab
Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
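For orientation, the change is a template-method refactor: NPUModelRunner gains a single `_convert_torch_format` hook that call sites use instead of calling `torch_npu.npu_format_cast` inline, and NPUTorchairModelRunner overrides the hook to force ND format. Below is a minimal, hardware-free sketch of the pattern, not the real classes: `npu_format_cast` is stubbed (the real torch_npu API only runs on Ascend NPUs), the `IS_310P` flag and constant values are illustrative, and the method signatures are simplified.

    import torch

    # Format constants as named in vllm_ascend.utils (values illustrative).
    ACL_FORMAT_FRACTAL_ND = 2
    ACL_FORMAT_FRACTAL_NZ = 29

    def npu_format_cast(tensor: torch.Tensor, acl_format: int) -> torch.Tensor:
        """Stand-in for torch_npu.npu_format_cast; a no-op off-NPU."""
        return tensor

    # Mirrors the module-level choice added in this PR: NZ on 310P, ND otherwise.
    IS_310P = False
    ACL_FORMAT = ACL_FORMAT_FRACTAL_NZ if IS_310P else ACL_FORMAT_FRACTAL_ND

    class NPUModelRunner:
        def _convert_torch_format(self, tensor: torch.Tensor) -> torch.Tensor:
            # Base runner: follow the SoC-wide default format.
            return npu_format_cast(tensor, ACL_FORMAT)

        def initialize_kv_cache(self) -> torch.Tensor:
            kv_cache = torch.zeros(4, 128, 8, 64)  # shape illustrative
            # Call sites no longer pick a format; the hook does.
            return self._convert_torch_format(kv_cache)

    class NPUTorchairModelRunner(NPUModelRunner):
        def _convert_torch_format(self, tensor: torch.Tensor) -> torch.Tensor:
            # Torchair graph mode always uses ND, even on 310P.
            return npu_format_cast(tensor, ACL_FORMAT_FRACTAL_ND)

    # The torchair runner inherits initialize_kv_cache unchanged, yet its
    # caches come out in ND format; only the hook differs.
    print(NPUTorchairModelRunner().initialize_kv_cache().shape)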
@@ -20,10 +20,11 @@
 from typing import Optional

 import torch
 import torch_npu
 from vllm.config import VllmConfig
 from vllm.forward_context import get_forward_context

-from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ,
+from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
                                maybe_converting_weight_acl_format)
 from vllm_ascend.worker.model_runner_v1 import NPUModelRunner
@@ -113,3 +114,7 @@ class NPUTorchairModelRunner(NPUModelRunner):
             with_prefill, is_torchair_compile, input_ids, positions,
             attn_metadata, num_tokens, intermediate_tensors, inputs_embeds)
         return hidden_states
+
+    def _convert_torch_format(self, kv_cache):
+        kv_cache = torch_npu.npu_format_cast(kv_cache, ACL_FORMAT_FRACTAL_ND)
+        return kv_cache
@@ -110,6 +110,9 @@ import vllm_ascend.envs as envs_ascend

 if is_310p():
     torch_npu.npu.set_compile_mode(jit_compile=False)
+    ACL_FORMAT = ACL_FORMAT_FRACTAL_NZ
+else:
+    ACL_FORMAT = ACL_FORMAT_FRACTAL_ND


 @dataclass
@@ -2047,8 +2050,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
                 if isinstance(module,
                               (MergedColumnParallelLinear,
                                QKVParallelLinear, RowParallelLinear)):
-                    module.weight.data = torch_npu.npu_format_cast(
-                        module.weight.data, ACL_FORMAT_FRACTAL_NZ)
+                    module.weight.data = self._convert_torch_format(
+                        module.weight.data)
         if self.drafter:
             logger.info("Loading drafter model...")
             if isinstance(self.drafter, EagleProposer):
@@ -2133,6 +2136,10 @@ class NPUModelRunner(LoRAModelRunnerMixin):
                 ge_cache=False)
         return self.torchair_compiled_models[batch_size]

+    def _convert_torch_format(self, tensor):
+        tensor = torch_npu.npu_format_cast(tensor, ACL_FORMAT)
+        return tensor
+
     def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
         """
         Initialize KV cache based on `kv_cache_config`.
@@ -2141,9 +2148,6 @@ class NPUModelRunner(LoRAModelRunnerMixin):
             cache size of each layer
         """
         self.kv_cache_config = kv_cache_config
-        import torch_npu
-        acl_format = ACL_FORMAT_FRACTAL_NZ if is_310p(
-        ) and not self.torchair_graph_enabled else ACL_FORMAT_FRACTAL_ND
         kv_caches: Dict[str, torch.Tensor] = {}

         def align_memory(tensor: torch.Tensor, alignment: int) -> torch.Tensor:
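The inline expression removed above and the new arrangement (module-level ACL_FORMAT plus the torchair override) select the same format in every case. Here is a quick, hardware-free check of that equivalence; `old_choice`/`new_choice` are illustrative helper names and the constant values are assumptions, not part of this PR.

    ACL_FORMAT_FRACTAL_ND = 2   # values illustrative; names from vllm_ascend.utils
    ACL_FORMAT_FRACTAL_NZ = 29

    def old_choice(is_310p: bool, torchair_graph_enabled: bool) -> int:
        # The expression deleted from initialize_kv_cache above.
        return (ACL_FORMAT_FRACTAL_NZ
                if is_310p and not torchair_graph_enabled
                else ACL_FORMAT_FRACTAL_ND)

    def new_choice(is_310p: bool, torchair_graph_enabled: bool) -> int:
        # Module-level default: NZ only on 310P...
        acl_format = ACL_FORMAT_FRACTAL_NZ if is_310p else ACL_FORMAT_FRACTAL_ND
        # ...and NPUTorchairModelRunner._convert_torch_format overrides to ND.
        return ACL_FORMAT_FRACTAL_ND if torchair_graph_enabled else acl_format

    for soc_310p in (False, True):
        for torchair in (False, True):
            assert old_choice(soc_310p, torchair) == new_choice(soc_310p, torchair)
    print("old and new format selection agree in all four cases")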
@@ -2202,7 +2206,6 @@ class NPUModelRunner(LoRAModelRunnerMixin):
                                   kv_cache_spec.head_size)
             dtype = kv_cache_spec.dtype
             if self.model_config.is_deepseek_mla:
-
                 num_blocks, block_size, num_kv_heads, head_size = kv_cache_shape
                 rope_dim = self.model_config.hf_text_config.qk_rope_head_dim
                 nope_dim = head_size - rope_dim
@@ -2218,10 +2221,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
                 nope_cache = torch.zeros(nope_cache_shape,
                                          dtype=dtype,
                                          device=self.device)
-                rope_cache = torch_npu.npu_format_cast(
-                    rope_cache, acl_format)
-                nope_cache = torch_npu.npu_format_cast(
-                    nope_cache, acl_format)
+                rope_cache = self._convert_torch_format(rope_cache)
+                nope_cache = self._convert_torch_format(nope_cache)
             else:

                 # In order to transfer kv cache through the reigster_memory api from llmdatadist, the memory
@@ -2259,8 +2260,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
                     kv_cache = torch.zeros(cache_shape,
                                            dtype=dtype,
                                            device=self.device)
-                    kv_cache = torch_npu.npu_format_cast(
-                        kv_cache, acl_format)
+                    kv_cache = self._convert_torch_format(kv_cache)
                 else:
                     cache_size = math.prod(cache_shape)
                     cache_size_aligned = cache_size + alignment