vllm-ascend/vllm_ascend/worker/worker_v1.py

#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
# Adapted from vllm-project/vllm/vllm/worker/gpu_worker.py
#

import gc
from typing import Dict, List, Optional

import torch
import torch.nn as nn
import torch_npu
from vllm import envs
from vllm.config import VllmConfig
from vllm.distributed import (ensure_kv_transfer_initialized,
                              ensure_model_parallel_initialized,
                              init_distributed_environment,
                              set_custom_all_reduce)
from vllm.logger import logger
from vllm.model_executor import set_random_seed
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
                                        KVCacheSpec)
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.utils import bind_kv_cache
from vllm.v1.worker.worker_base import WorkerBase

from vllm_ascend.platform import NPUPlatform
from vllm_ascend.worker.model_runner_v1 import NPUModelRunner


class NPUWorker(WorkerBase):

    def __init__(
            self,
            vllm_config: VllmConfig,
            local_rank: int,
            rank: int,
            distributed_init_method: str,
            is_driver_worker: bool = False,
            # Additional parameters for compatibility with vLLM.
            **kwargs):
        """Initialize the worker for Ascend."""
        # Register patches for vLLM.
        from vllm_ascend.utils import adapt_patch
        adapt_patch()
        # Register custom ops at worker initialization.
        from vllm_ascend import ops  # noqa: F401

        super().__init__(vllm_config=vllm_config,
                         local_rank=local_rank,
                         rank=rank,
                         distributed_init_method=distributed_init_method,
                         is_driver_worker=is_driver_worker)

        if self.cache_config.cache_dtype == "auto":
            self.cache_dtype = self.model_config.dtype
        else:
            self.cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[
                self.cache_config.cache_dtype]

        if self.model_config.trust_remote_code:
            # Note: lazy import to avoid importing torch before initialization.
            from vllm.utils import init_cached_hf_modules
            init_cached_hf_modules()

        self.profiler = self._init_profiler()

    def init_device(self):
        if self.device_config.device.type == "npu":
            self.device = torch.device(f"npu:{self.local_rank}")
            NPUPlatform.set_device(self.device)
            NPUPlatform.empty_cache()
            self.init_npu_memory = NPUPlatform.mem_get_info()[0]
        else:
            info = f"Unsupported device type: {self.device_config.device}"
            logger.error(info)
            raise RuntimeError(info)
        # Initialize the distributed environment.
        self._init_worker_distributed_environment()
        # Set random seed.
        set_random_seed(self.model_config.seed)
        # Init ModelRunner here, so that we have access to self.device.
        self.model_runner = NPUModelRunner(self.vllm_config, self.device)

    def determine_available_memory(self) -> int:
        kv_caches: Dict[str, torch.Tensor] = {}
        kv_cache_spec = self.model_runner.get_kv_cache_spec()
        for layer_name, layer_spec in kv_cache_spec.items():
            if isinstance(layer_spec, FullAttentionSpec):
                # Use an empty tensor instead of `None` to force Dynamo to
                # pass it by reference, rather than by specializing on the
                # value `None`.
                npu_k_cache = torch.tensor([],
                                           dtype=layer_spec.dtype,
                                           device=self.device)
                npu_v_cache = torch.tensor([],
                                           dtype=layer_spec.dtype,
                                           device=self.device)
                kv_caches[layer_name] = (npu_k_cache, npu_v_cache)
            else:
                raise NotImplementedError

        runner_kv_caches: List[torch.Tensor] = []
        bind_kv_cache(
            kv_caches,
            self.vllm_config.compilation_config.static_forward_context,
            runner_kv_caches)

        # Profile the memory usage of the model and get the maximum number of
        # cache blocks that can be allocated with the remaining free memory.
        NPUPlatform.empty_cache()

        # Execute a forward pass with dummy inputs to profile the memory usage
        # of the model.
        self.model_runner.profile_run()

        # Calculate the number of blocks that can be allocated with the
        # profiled peak memory.
        free_npu_memory, total_npu_memory = NPUPlatform.mem_get_info()
        # NOTE(woosuk): Here we assume that the other processes using the same
        # NPU did not change their memory usage during the profiling.
        peak_memory = self.init_npu_memory - free_npu_memory
        assert peak_memory > 0, (
            "Error in memory profiling. "
            f"Initial free memory {self.init_npu_memory}, current free memory"
            f" {free_npu_memory}. This happens when the NPU memory was "
            "not properly cleaned up before initializing the vLLM instance.")

        gc.collect()
        # TODO: this call can be removed once empty_cache is unified in
        # Worker.determine_num_available_blocks().
        NPUPlatform.empty_cache()
        usable_memory_size = total_npu_memory * self.cache_config.gpu_memory_utilization - peak_memory
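        # Illustrative arithmetic (example numbers, not measured values): with
        # 64 GiB of total NPU memory, gpu_memory_utilization=0.9 and a profiled
        # peak of 20 GiB, the KV cache budget is 64 * 0.9 - 20 = 37.6 GiB.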
        npu_kv_cache_bytes = max(usable_memory_size, 0)
        logger.info(
            f"Available memory: {usable_memory_size}, total memory: {total_npu_memory}"
        )
        return int(npu_kv_cache_bytes)

    def execute_model(
        self,
        scheduler_output: "SchedulerOutput",
    ) -> Optional[ModelRunnerOutput]:
        output = self.model_runner.execute_model(scheduler_output)
        # Only the driver worker (rank 0) returns the model output; the other
        # ranks return None.
        return output if self.rank == 0 else None

    def load_model(self) -> None:
        self.model_runner.load_model()

    def compile_or_warm_up_model(self) -> None:
        if not self.model_config.enforce_eager:
            logger.warning("Graph capture is not supported on NPU.")
        # Reset the seed to ensure that the random state is not affected by
        # the model initialization and profiling.
        set_random_seed(self.model_config.seed)

    def get_model(self) -> nn.Module:
        return self.model_runner.get_model()

    def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
        return self.model_runner.get_kv_cache_spec()

    def initialize_from_config(self, kv_cache_config: KVCacheConfig) -> None:
        """Allocate NPU KV cache with the specified kv_cache_config."""
        self.model_runner.initialize_kv_cache(kv_cache_config)

    def initialize_cache(self, kv_cache_configs: List[KVCacheConfig]) -> None:
        """Allocate NPU KV cache with the specified kv_cache_config."""
        kv_cache_config = kv_cache_configs[self.rank]
        self.model_runner.initialize_kv_cache(kv_cache_config)

    def profile(self, is_start: bool = True):
        if self.profiler is None:
            raise RuntimeError("Profiler is not enabled.")
        if is_start:
            self.profiler.start()
        else:
            self.profiler.stop()

    def _init_worker_distributed_environment(self) -> None:
        """Initialize the distributed environment."""
        set_custom_all_reduce(
            not self.parallel_config.disable_custom_all_reduce)
        init_distributed_environment(self.parallel_config.world_size,
                                     self.rank, self.distributed_init_method,
                                     self.local_rank, "hccl")
        ensure_model_parallel_initialized(
            self.parallel_config.tensor_parallel_size,
            self.parallel_config.pipeline_parallel_size)
        ensure_kv_transfer_initialized(self.vllm_config)

    def _init_profiler(self):
        # Torch profiler. Enabled and configured through env vars:
        # VLLM_TORCH_PROFILER_DIR=/path/to/save/trace
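        # Example usage (illustrative command; the directory and model are
        # placeholders, not values prescribed by this file):
        #   VLLM_TORCH_PROFILER_DIR=/tmp/npu_profile vllm serve <model>
        # Collected traces can then be inspected with the Ascend profiling
        # tools or via the trace handler configured below.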
        if envs.VLLM_TORCH_PROFILER_DIR:
            torch_profiler_trace_dir = envs.VLLM_TORCH_PROFILER_DIR
            logger.info("Profiling enabled. Traces will be saved to: %s",
                        torch_profiler_trace_dir)
            experimental_config = torch_npu.profiler._ExperimentalConfig(
                export_type=torch_npu.profiler.ExportType.Text,
                profiler_level=torch_npu.profiler.ProfilerLevel.Level0,
                msprof_tx=False,
                aic_metrics=torch_npu.profiler.AiCMetrics.AiCoreNone,
                l2_cache=False,
                op_attr=False,
                data_simplification=False,
                record_op_args=False,
                gc_detect_threshold=None,
            )
            return torch_npu.profiler.profile(
                activities=[
                    torch_npu.profiler.ProfilerActivity.CPU,
                    torch_npu.profiler.ProfilerActivity.NPU,
                ],
                with_stack=True,
                profile_memory=True,
                with_modules=True,
                experimental_config=experimental_config,
                on_trace_ready=torch_npu.profiler.tensorboard_trace_handler(
                    torch_profiler_trace_dir))
        else:
            return None