[4/N][Refactor] torchair model runner refactor (#2208)

There is a lot of torchair code in the model runner, which makes the code hard to maintain. We'll create a new torchair_model_runner to split out the torchair-related logic. Following the workflow in #2203, this is the first PR.

What this PR does:

Create a common function `_convert_torch_format` for `initialize_kv_cache`.
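
For orientation, here is a minimal sketch of the resulting structure. Names follow the diff below; the surrounding runner logic is elided, the two classes actually live in separate modules, and the exact import location of `is_310p` is an assumption:

```python
import torch_npu

from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
                               is_310p)

# Module-level default in model_runner_v1: NZ on 310P devices, ND otherwise.
ACL_FORMAT = ACL_FORMAT_FRACTAL_NZ if is_310p() else ACL_FORMAT_FRACTAL_ND


class NPUModelRunner:
    def _convert_torch_format(self, tensor):
        # Shared helper: cast weight / KV-cache tensors to the default format.
        tensor = torch_npu.npu_format_cast(tensor, ACL_FORMAT)
        return tensor


class NPUTorchairModelRunner(NPUModelRunner):
    def _convert_torch_format(self, kv_cache):
        # Torchair override: always cast the KV cache to the ND format.
        kv_cache = torch_npu.npu_format_cast(kv_cache, ACL_FORMAT_FRACTAL_ND)
        return kv_cache
```

In the actual change, `NPUTorchairModelRunner` imports the base class from `vllm_ascend.worker.model_runner_v1`, so `initialize_kv_cache` can call `self._convert_torch_format(...)` without knowing which runner is in use.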


- vLLM version: v0.10.0
- vLLM main: 14a5d903ab

Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
Author: wangxiyuan
Date: 2025-08-11 21:39:24 +08:00 (committed by GitHub)
Parent: eb43a475f4
Commit: c8b0f5f799

2 changed files with 18 additions and 13 deletions


@@ -20,10 +20,11 @@
 from typing import Optional

 import torch
+import torch_npu
 from vllm.config import VllmConfig
 from vllm.forward_context import get_forward_context

-from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ,
+from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
                                maybe_converting_weight_acl_format)
 from vllm_ascend.worker.model_runner_v1 import NPUModelRunner
@@ -113,3 +114,7 @@ class NPUTorchairModelRunner(NPUModelRunner):
             with_prefill, is_torchair_compile, input_ids, positions,
             attn_metadata, num_tokens, intermediate_tensors, inputs_embeds)
         return hidden_states
+
+    def _convert_torch_format(self, kv_cache):
+        kv_cache = torch_npu.npu_format_cast(kv_cache, ACL_FORMAT_FRACTAL_ND)
+        return kv_cache


@@ -110,6 +110,9 @@ import vllm_ascend.envs as envs_ascend

 if is_310p():
     torch_npu.npu.set_compile_mode(jit_compile=False)
+    ACL_FORMAT = ACL_FORMAT_FRACTAL_NZ
+else:
+    ACL_FORMAT = ACL_FORMAT_FRACTAL_ND


 @dataclass
@@ -2047,8 +2050,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
                 if isinstance(module,
                               (MergedColumnParallelLinear,
                                QKVParallelLinear, RowParallelLinear)):
-                    module.weight.data = torch_npu.npu_format_cast(
-                        module.weight.data, ACL_FORMAT_FRACTAL_NZ)
+                    module.weight.data = self._convert_torch_format(
+                        module.weight.data)
         if self.drafter:
             logger.info("Loading drafter model...")
             if isinstance(self.drafter, EagleProposer):
@@ -2133,6 +2136,10 @@ class NPUModelRunner(LoRAModelRunnerMixin):
             ge_cache=False)
         return self.torchair_compiled_models[batch_size]

+    def _convert_torch_format(self, tensor):
+        tensor = torch_npu.npu_format_cast(tensor, ACL_FORMAT)
+        return tensor
+
     def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
         """
         Initialize KV cache based on `kv_cache_config`.
@@ -2141,9 +2148,6 @@ class NPUModelRunner(LoRAModelRunnerMixin):
             cache size of each layer
         """
         self.kv_cache_config = kv_cache_config
-        import torch_npu
-        acl_format = ACL_FORMAT_FRACTAL_NZ if is_310p(
-        ) and not self.torchair_graph_enabled else ACL_FORMAT_FRACTAL_ND
         kv_caches: Dict[str, torch.Tensor] = {}

         def align_memory(tensor: torch.Tensor, alignment: int) -> torch.Tensor:
@@ -2202,7 +2206,6 @@ class NPUModelRunner(LoRAModelRunnerMixin):
                                   kv_cache_spec.head_size)
                 dtype = kv_cache_spec.dtype
                 if self.model_config.is_deepseek_mla:
-
                     num_blocks, block_size, num_kv_heads, head_size = kv_cache_shape
                     rope_dim = self.model_config.hf_text_config.qk_rope_head_dim
                     nope_dim = head_size - rope_dim
@@ -2218,10 +2221,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
                     nope_cache = torch.zeros(nope_cache_shape,
                                              dtype=dtype,
                                              device=self.device)
-                    rope_cache = torch_npu.npu_format_cast(
-                        rope_cache, acl_format)
-                    nope_cache = torch_npu.npu_format_cast(
-                        nope_cache, acl_format)
+                    rope_cache = self._convert_torch_format(rope_cache)
+                    nope_cache = self._convert_torch_format(nope_cache)
                 else:
                     # In order to transfer kv cache through the reigster_memory api from llmdatadist, the memory
@@ -2259,8 +2260,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
                     kv_cache = torch.zeros(cache_shape,
                                            dtype=dtype,
                                            device=self.device)
-                    kv_cache = torch_npu.npu_format_cast(
-                        kv_cache, acl_format)
+                    kv_cache = self._convert_torch_format(kv_cache)
                 else:
                     cache_size = math.prod(cache_shape)
                     cache_size_aligned = cache_size + alignment
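
Taken together, the KV-cache format selection in `initialize_kv_cache` is now handled by method dispatch instead of the removed inline expression. A hypothetical check (not part of the PR) that the two are equivalent, assuming the torchair runner is used exactly when torchair graph mode is enabled:

```python
# Stand-ins for the real ACL format enum values from vllm_ascend.utils.
ACL_FORMAT_FRACTAL_ND = "ND"
ACL_FORMAT_FRACTAL_NZ = "NZ"

def old_format(is_310p: bool, torchair_graph_enabled: bool) -> str:
    # Inline expression removed from initialize_kv_cache by this commit.
    return (ACL_FORMAT_FRACTAL_NZ
            if is_310p and not torchair_graph_enabled
            else ACL_FORMAT_FRACTAL_ND)

def new_format(is_310p: bool, torchair_graph_enabled: bool) -> str:
    # New behavior: NPUTorchairModelRunner overrides _convert_torch_format to
    # always use ND; the base NPUModelRunner uses the module-level ACL_FORMAT.
    if torchair_graph_enabled:
        return ACL_FORMAT_FRACTAL_ND
    return ACL_FORMAT_FRACTAL_NZ if is_310p else ACL_FORMAT_FRACTAL_ND

for p in (False, True):
    for g in (False, True):
        assert old_format(p, g) == new_format(p, g)
```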