[4/N][Refactor] torchair model runner refactor (#2208)
There is a lot of torchair code in the model runner, which makes the code hard to maintain. We'll create a new torchair_model_runner to split out the torchair-related logic. Following the workflow in #2203, this is the first PR.
What this PR does:
create a common function `_convert_torch_format` for initialize_kv_cache
- vLLM version: v0.10.0
- vLLM main: 14a5d903ab
Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
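In short, the base NPUModelRunner in vllm_ascend/worker/model_runner_v1.py now exposes a `_convert_torch_format` hook that casts tensors to a module-level ACL_FORMAT (FRACTAL_NZ on 310P, FRACTAL_ND otherwise), and NPUTorchairModelRunner overrides it to always cast to FRACTAL_ND. The weight-format conversion and the initialize_kv_cache call sites then stop choosing an ACL format inline and simply call the hook. A minimal sketch of that pattern follows; `npu_format_cast` below is a stub for torch_npu.npu_format_cast (which requires an Ascend NPU), the ACL_* constants are illustrative strings rather than the real ACL enum values, and the class names and IS_310P flag are simplified for illustration:

```python
# Minimal sketch of the hook pattern this PR introduces; not the real classes.
ACL_FORMAT_FRACTAL_ND = "FRACTAL_ND"
ACL_FORMAT_FRACTAL_NZ = "FRACTAL_NZ"


def npu_format_cast(tensor, acl_format):
    # Stub for torch_npu.npu_format_cast: just report the target format.
    return f"{tensor} -> {acl_format}"


IS_310P = True  # assumption: pretend we are on a 310P SoC, where the base runner prefers NZ
ACL_FORMAT = ACL_FORMAT_FRACTAL_NZ if IS_310P else ACL_FORMAT_FRACTAL_ND


class BaseModelRunner:
    """Stands in for NPUModelRunner: follows the module-level ACL_FORMAT."""

    def _convert_torch_format(self, tensor):
        return npu_format_cast(tensor, ACL_FORMAT)

    def initialize_kv_cache(self, kv_cache):
        # Call sites no longer pick an acl_format inline; they go through the hook.
        return self._convert_torch_format(kv_cache)


class TorchairModelRunner(BaseModelRunner):
    """Stands in for NPUTorchairModelRunner: torchair graphs always want ND."""

    def _convert_torch_format(self, kv_cache):
        return npu_format_cast(kv_cache, ACL_FORMAT_FRACTAL_ND)


if __name__ == "__main__":
    print(BaseModelRunner().initialize_kv_cache("kv_cache"))      # kv_cache -> FRACTAL_NZ
    print(TorchairModelRunner().initialize_kv_cache("kv_cache"))  # kv_cache -> FRACTAL_ND
```

The benefit is that the 310P/torchair format decision lives in one overridable method per runner class instead of being repeated at every npu_format_cast call site.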
@@ -20,10 +20,11 @@
 from typing import Optional
 
 import torch
+import torch_npu
 from vllm.config import VllmConfig
 from vllm.forward_context import get_forward_context
 
-from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ,
+from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
                                maybe_converting_weight_acl_format)
 from vllm_ascend.worker.model_runner_v1 import NPUModelRunner
 
@@ -113,3 +114,7 @@ class NPUTorchairModelRunner(NPUModelRunner):
             with_prefill, is_torchair_compile, input_ids, positions,
             attn_metadata, num_tokens, intermediate_tensors, inputs_embeds)
         return hidden_states
+
+    def _convert_torch_format(self, kv_cache):
+        kv_cache = torch_npu.npu_format_cast(kv_cache, ACL_FORMAT_FRACTAL_ND)
+        return kv_cache
@@ -110,6 +110,9 @@ import vllm_ascend.envs as envs_ascend
 
 if is_310p():
     torch_npu.npu.set_compile_mode(jit_compile=False)
+    ACL_FORMAT = ACL_FORMAT_FRACTAL_NZ
+else:
+    ACL_FORMAT = ACL_FORMAT_FRACTAL_ND
 
 
 @dataclass
@@ -2047,8 +2050,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
                 if isinstance(module,
                               (MergedColumnParallelLinear,
                                QKVParallelLinear, RowParallelLinear)):
-                    module.weight.data = torch_npu.npu_format_cast(
-                        module.weight.data, ACL_FORMAT_FRACTAL_NZ)
+                    module.weight.data = self._convert_torch_format(
+                        module.weight.data)
         if self.drafter:
             logger.info("Loading drafter model...")
             if isinstance(self.drafter, EagleProposer):
@@ -2133,6 +2136,10 @@ class NPUModelRunner(LoRAModelRunnerMixin):
                 ge_cache=False)
         return self.torchair_compiled_models[batch_size]
 
+    def _convert_torch_format(self, tensor):
+        tensor = torch_npu.npu_format_cast(tensor, ACL_FORMAT)
+        return tensor
+
     def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
         """
         Initialize KV cache based on `kv_cache_config`.
@@ -2141,9 +2148,6 @@ class NPUModelRunner(LoRAModelRunnerMixin):
             cache size of each layer
         """
         self.kv_cache_config = kv_cache_config
-        import torch_npu
-        acl_format = ACL_FORMAT_FRACTAL_NZ if is_310p(
-        ) and not self.torchair_graph_enabled else ACL_FORMAT_FRACTAL_ND
         kv_caches: Dict[str, torch.Tensor] = {}
 
         def align_memory(tensor: torch.Tensor, alignment: int) -> torch.Tensor:
@@ -2202,7 +2206,6 @@ class NPUModelRunner(LoRAModelRunnerMixin):
                                               kv_cache_spec.head_size)
                 dtype = kv_cache_spec.dtype
                 if self.model_config.is_deepseek_mla:
-
                     num_blocks, block_size, num_kv_heads, head_size = kv_cache_shape
                     rope_dim = self.model_config.hf_text_config.qk_rope_head_dim
                     nope_dim = head_size - rope_dim
@@ -2218,10 +2221,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
                     nope_cache = torch.zeros(nope_cache_shape,
                                              dtype=dtype,
                                              device=self.device)
-                    rope_cache = torch_npu.npu_format_cast(
-                        rope_cache, acl_format)
-                    nope_cache = torch_npu.npu_format_cast(
-                        nope_cache, acl_format)
+                    rope_cache = self._convert_torch_format(rope_cache)
+                    nope_cache = self._convert_torch_format(nope_cache)
                 else:
 
                     # In order to transfer kv cache through the reigster_memory api from llmdatadist, the memory
@@ -2259,8 +2260,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
                     kv_cache = torch.zeros(cache_shape,
                                            dtype=dtype,
                                            device=self.device)
-                    kv_cache = torch_npu.npu_format_cast(
-                        kv_cache, acl_format)
+                    kv_cache = self._convert_torch_format(kv_cache)
                 else:
                     cache_size = math.prod(cache_shape)
                     cache_size_aligned = cache_size + alignment