Mirror of https://github.com/vllm-project/vllm-ascend.git (synced 2025-10-20 13:43:53 +08:00)
### What this PR does / why we need it?
Adapt deepseek-v3.2 to vLLM 0.11.0, removing the now-unneeded patch. The final goal is to remove all patches and align the code architecture with vLLM, so the following work is planned for the next PRs.

TODO:
- [x] Remove the patch on the attention spec
- [ ] Refactor the KV cache creation logic

### Does this PR introduce _any_ user-facing change?
N/A

### How was this patch tested?
1. CI passed with the existing tests.
2. Tests pass with deepseek-v3.2-exp.

- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

Signed-off-by: MengqingCao <cmq0113@163.com>
426 lines · 18 KiB · Python
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
# Adapted from vllm-project/vllm/vllm/worker/gpu_worker.py
#
import copy
from typing import Optional, Union

import torch
import torch.nn as nn
import torch_npu
import vllm.envs as envs_vllm
from torch_npu.op_plugin.atb._atb_ops import _register_atb_extensions
from torch_npu.profiler import dynamic_profile as dp
from vllm.config import CUDAGraphMode, VllmConfig
from vllm.distributed import (ensure_model_parallel_initialized,
                              init_distributed_environment)
from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized
from vllm.distributed.parallel_state import get_pp_group, get_tp_group
from vllm.logger import logger
from vllm.lora.request import LoRARequest
from vllm.sequence import IntermediateTensors
from vllm.tasks import SupportedTask
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, GiB_bytes
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput,
                             DraftTokenIds, ModelRunnerOutput)
from vllm.v1.worker.worker_base import WorkerBase

import vllm_ascend.envs as envs_ascend
from vllm_ascend.ascend_config import init_ascend_config
from vllm_ascend.device_allocator.camem import CaMemAllocator
from vllm_ascend.distributed.parallel_state import init_ascend_model_parallel
from vllm_ascend.platform import NPUPlatform
from vllm_ascend.utils import (init_ascend_soc_version,
                               register_ascend_customop, sleep_mode_enabled,
                               try_register_lib)
from vllm_ascend.worker.model_runner_v1 import NPUModelRunner
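
# Register the torch.npu stream APIs as Dynamo in-graph functions so that
# tracing them does not force a graph break (the NPU analogue of the rules
# torch ships for the CUDA stream APIs).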
torch._dynamo.trace_rules.clear_lru_cache()  # noqa: E402
from torch._dynamo.variables import TorchInGraphFunctionVariable  # noqa: E402

torch_non_c_binding_in_graph_functions_npu = dict.fromkeys(
    ["torch.npu.current_stream"],
    TorchInGraphFunctionVariable,
)  # noqa: E402
torch_non_c_binding_in_graph_functions_npu[
    "torch.npu.stream"] = TorchInGraphFunctionVariable  # noqa: E402
torch._dynamo.trace_rules.torch_name_rule_map.append(
    torch_non_c_binding_in_graph_functions_npu)  # noqa: E402

class NPUWorker(WorkerBase):

    def __init__(
            self,
            vllm_config: VllmConfig,
            local_rank: int,
            rank: int,
            distributed_init_method: str,
            is_driver_worker: bool = False,
            # Additional parameters for compatibility with vllm
            **kwargs):
        """Initialize the worker for Ascend."""
        # Register patches for vLLM.
        from vllm_ascend.utils import adapt_patch
        adapt_patch()
        # Register ops when the worker initializes.
        from vllm_ascend import ops
        ops.register_dummy_fusion_op()
        _register_atb_extensions()
        register_ascend_customop(vllm_config)
        # Init ascend config and SoC version.
        init_ascend_config(vllm_config)
        init_ascend_soc_version()
        use_sparse = False
        if vllm_config.model_config is not None:
            use_sparse = hasattr(vllm_config.model_config.hf_config,
                                 "index_topk")
        if use_sparse:
            # Direct import instead of try_register_lib to ensure proper
            # error handling when custom_ops is necessary but not available
            # (e.g., in DeepSeek V3.2 deployments).
            # yapf: disable
            import custom_ops  # type: ignore # noqa

            # yapf: enable
            logger.info(
                "custom_ops module loaded successfully. Custom operators like "
                "torch.ops.custom.npu_sparse_flash_attention are now available."
            )

        super().__init__(vllm_config=vllm_config,
                         local_rank=local_rank,
                         rank=rank,
                         distributed_init_method=distributed_init_method,
                         is_driver_worker=is_driver_worker)

        # Try to import mindie_turbo to accelerate vLLM inference.
        try_register_lib(
            "mindie_turbo",
            "MindIE Turbo is installed. vLLM inference will be accelerated with MindIE Turbo."
        )
        if self.cache_config.cache_dtype == "auto":
            self.cache_dtype = self.model_config.dtype
        else:
            self.cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[
                self.cache_config.cache_dtype]

        if self.model_config.trust_remote_code:
            # Note: lazy import to avoid importing torch before initialization.
            from vllm.utils import init_cached_hf_modules
            init_cached_hf_modules()

        self.profiler = self._init_profiler()
        if sleep_mode_enabled():
            # Buffers saved before sleep.
            self._sleep_saved_buffers: dict[str, torch.Tensor] = {}

        # FixMe: this is a patch for the issue caused by
        # https://github.com/vllm-project/vllm/commit/de94289a98d7ec52a5ef02719e01a1db8b505170
        from vllm.model_executor.layers.linear import \
            WEIGHT_LOADER_V2_SUPPORTED
        if "UnquantizedLinearMethod" in WEIGHT_LOADER_V2_SUPPORTED:
            WEIGHT_LOADER_V2_SUPPORTED.remove("UnquantizedLinearMethod")

    def sleep(self, level: int = 1) -> None:
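        """Release NPU memory while the engine is idle.

        Level 1 offloads the model weights to host memory and discards the
        KV cache; level 2 discards the weights as well, saving the model's
        named buffers on the CPU so wake_up() can restore them.
        """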
        if not sleep_mode_enabled():
            raise ValueError(
                "Sleep mode is not enabled. Please compile vllm-ascend with COMPILE_CUSTOM_KERNELS=1."
            )
        free_bytes_before_sleep = NPUPlatform.mem_get_info()[0]
        # Save the buffers before level 2 sleep.
        if level == 2:
            model = self.model_runner.model
            self._sleep_saved_buffers = {
                name: buffer.cpu().clone()
                for name, buffer in model.named_buffers()
            }
        allocator = CaMemAllocator.get_instance()
        allocator.sleep(offload_tags=("weights", ) if level == 1 else tuple())
        free_bytes_after_sleep, total = NPUPlatform.mem_get_info()
        freed_bytes = free_bytes_after_sleep - free_bytes_before_sleep
        used_bytes = total - free_bytes_after_sleep
        assert freed_bytes >= 0, "Memory usage increased after sleeping."
        logger.info(
            "Sleep mode freed %.2f GiB memory, "
            "%.2f GiB memory is still in use.", freed_bytes / GiB_bytes,
            used_bytes / GiB_bytes)

    def wake_up(self, tags: Optional[list[str]] = None) -> None:
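        """Reload memory pools released by sleep().

        If `tags` is None, all pools are woken up; otherwise only the pools
        with the given tags (e.g. "weights", "kv_cache") are restored.
        Buffers saved before a level 2 sleep are then copied back.
        """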
        if not sleep_mode_enabled():
            raise ValueError(
                "Sleep mode is not enabled. Please compile vllm-ascend with COMPILE_CUSTOM_KERNELS=1."
            )
        allocator = CaMemAllocator.get_instance()
        allocator.wake_up(tags=tags)

        # Restore the buffers after level 2 sleep.
        if len(self._sleep_saved_buffers):
            model = self.model_runner.model
            for name, buffer in model.named_buffers():
                if name in self._sleep_saved_buffers:
                    buffer.data.copy_(self._sleep_saved_buffers[name].data)
            self._sleep_saved_buffers = {}

    def initialize_cache(self, num_gpu_blocks: int,
                         num_cpu_blocks: int) -> None:
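        """Store the NPU and CPU KV cache block counts in the cache config."""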
        self.cache_config.num_gpu_blocks = num_gpu_blocks
        self.cache_config.num_cpu_blocks = num_cpu_blocks

    def _init_device(self):
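        """Bind this worker to its NPU, record the initial free memory, and
        initialize the distributed environment and random seed."""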
        device = torch.device(f"npu:{self.local_rank}")
        NPUPlatform.set_device(device)
        NPUPlatform.empty_cache()
        self.init_npu_memory = NPUPlatform.mem_get_info()[0]
        # Initialize the distributed environment.
        self._init_worker_distributed_environment()
        # Set random seed.
        NPUPlatform.seed_everything(self.model_config.seed)
        return device

    def init_device(self):
        device = self._init_device()
        # Init ModelRunner here, so that we have access to self.device.
        self.model_runner = NPUModelRunner(self.vllm_config, device)

    def determine_available_memory(self) -> int:
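        """Profile a dummy forward pass and return the number of bytes that
        can be devoted to the KV cache under gpu_memory_utilization."""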
        # Profile the memory usage of the model and get the maximum number of
        # cache blocks that can be allocated with the remaining free memory.
        NPUPlatform.clear_npu_memory()

        # Execute a forward pass with dummy inputs to profile the memory usage
        # of the model.
        _, total_npu_memory = NPUPlatform.mem_get_info()
        self.model_runner.profile_run()

        # Calculate the number of blocks that can be allocated with the
        # profiled peak memory.
        free_npu_memory, _ = NPUPlatform.mem_get_info()
        # NOTE(woosuk): Here we assume that the other processes using the same
        # NPU did not change their memory usage during the profiling.
        assert self.init_npu_memory > free_npu_memory, (
            "Error in memory profiling. "
            f"Initial free memory {self.init_npu_memory}, current free memory"
            f" {free_npu_memory}. This happens when the NPU memory was "
            "not properly cleaned up before initializing the vLLM instance.")

        # Get the peak memory allocation recorded by torch.
        peak_memory = torch_npu.npu.memory_stats()["allocated_bytes.all.peak"]
        # TODO: this step can be dropped once the empty_cache in
        # Worker.determine_num_available_blocks() is unified.
        NPUPlatform.empty_cache()
        torch_allocated_bytes = torch_npu.npu.memory_stats(
        )["allocated_bytes.all.current"]
        total_allocated_bytes = torch_npu.npu.mem_get_info(
        )[1] - torch_npu.npu.mem_get_info()[0]
        non_torch_allocations = total_allocated_bytes - torch_allocated_bytes
        if non_torch_allocations > 0:
            peak_memory += non_torch_allocations
        available_kv_cache_memory = int(
            total_npu_memory * self.cache_config.gpu_memory_utilization -
            peak_memory)
        available_kv_cache_memory = int(max(available_kv_cache_memory, 0))
        logger.info(
            f"Available memory: {available_kv_cache_memory}, total memory: {total_npu_memory}"
        )
        return available_kv_cache_memory

    def execute_model(
        self,
        scheduler_output: "SchedulerOutput",
    ) -> Optional[Union[ModelRunnerOutput, AsyncModelRunnerOutput]]:
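        """Execute one scheduler step.

        Non-first pipeline-parallel ranks first receive IntermediateTensors
        from the previous rank; non-last ranks forward their hidden states to
        the next rank and return either None or a placeholder output that
        carries only the KV connector status.
        """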
        # Enable msMonitor to monitor the performance of vllm-ascend.
        if envs_ascend.MSMONITOR_USE_DAEMON:
            dp.step()

        intermediate_tensors = None
        forward_pass = scheduler_output.total_num_scheduled_tokens > 0
        if forward_pass and not get_pp_group().is_first_rank:
            intermediate_tensors = IntermediateTensors(
                get_pp_group().recv_tensor_dict(
                    all_gather_group=get_tp_group()))

        output = self.model_runner.execute_model(scheduler_output,
                                                 intermediate_tensors)
        if isinstance(output, (ModelRunnerOutput, AsyncModelRunnerOutput)):
            return output

        assert isinstance(output, IntermediateTensors)
        parallel_config = self.vllm_config.parallel_config
        assert parallel_config.distributed_executor_backend != (
            "external_launcher") and not get_pp_group().is_last_rank

        get_pp_group().send_tensor_dict(output.tensors,
                                        all_gather_group=get_tp_group())

        kv_connector_output = output.kv_connector_output
        if not kv_connector_output:
            return None

        # In case of PP with kv transfer, we need to pass through the
        # kv_connector_output.
        if (not kv_connector_output.finished_sending
                and not kv_connector_output.finished_recving):
            return EMPTY_MODEL_RUNNER_OUTPUT
        output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT)
        output.kv_connector_output = kv_connector_output
        return output

    def load_model(self) -> None:
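        """Load the model, allocating its weights from the CaMem "weights"
        pool when sleep mode is enabled so sleep()/wake_up() can manage
        them."""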
        if self.vllm_config.model_config.enable_sleep_mode:
            allocator = CaMemAllocator.get_instance()
            assert allocator.get_current_usage() == 0, (
                "Sleep mode can only be "
                "used for one instance per process.")
            context = allocator.use_memory_pool(tag="weights")
        else:
            from contextlib import nullcontext
            context = nullcontext()  # type: ignore
        with context:
            self.model_runner.load_model()

    def compile_or_warm_up_model(self) -> None:
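        """Warm up EPLB and the compiled sizes, capture model graphs unless
        eager mode is enforced, then warm up ATB and reset the random seed."""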
        # Note: need to adapt for graph mode.
        self.model_runner.eplb_warmup()
        warmup_sizes = (self.vllm_config.compilation_config.compile_sizes
                        or []).copy()
        if not self.model_config.enforce_eager:
            warmup_sizes = [
                x for x in warmup_sizes if x not in
                self.vllm_config.compilation_config.cudagraph_capture_sizes
            ]
        for size in sorted(warmup_sizes, reverse=True):
            logger.info("Compiling and warming up model for size %d", size)
            self.model_runner._dummy_run(size)
        if not self.model_config.enforce_eager:
            self.model_runner.capture_model()
        # Call ATB matmul to warm up; otherwise, the first operation
        # (ReshapeAndCache) may cause performance degradation at runtime.
        self._warm_up_atb()
        # Reset the seed to ensure that the random state is not affected by
        # the model initialization and profiling.
        NPUPlatform.seed_everything(self.model_config.seed)

    def _warm_up_atb(self):
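        """Issue a small ATB matmul so the ATB runtime is initialized before
        the first real operator (e.g. ReshapeAndCache) runs."""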
        x = torch.rand((2, 4), dtype=torch.float16).npu()
        weight = torch.rand((2, 4), dtype=torch.float16).npu()
        c = torch.rand((4, 4), dtype=torch.float32).npu()
        torch_npu._npu_matmul_add_fp32(x, weight, c)

    def get_model(self) -> nn.Module:
        return self.model_runner.get_model()

    def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
        return self.model_runner.get_kv_cache_spec()

    def initialize_from_config(self, kv_cache_config: KVCacheConfig) -> None:
        """Allocate NPU KV cache with the specified kv_cache_config."""
        if self.vllm_config.model_config.enable_sleep_mode:
            allocator = CaMemAllocator.get_instance()
            context = allocator.use_memory_pool(tag="kv_cache")
        else:
            from contextlib import nullcontext
            context = nullcontext()  # type: ignore
        with context:
            self.model_runner.initialize_kv_cache(kv_cache_config)

    def profile(self, is_start: bool = True):
        if self.profiler is None:
            raise RuntimeError("Profiler is not enabled.")
        if is_start:
            self.profiler.start()
        else:
            self.profiler.stop()

    def add_lora(self, lora_request: LoRARequest) -> bool:
        return self.model_runner.add_lora(lora_request)

    def remove_lora(self, lora_id: int) -> bool:
        return self.model_runner.remove_lora(lora_id)

    def list_loras(self) -> set[int]:
        return self.model_runner.list_loras()

    def pin_lora(self, lora_id: int) -> bool:
        return self.model_runner.pin_lora(lora_id)

    def execute_dummy_batch(self) -> None:
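        """Run a one-token uniform-decode dummy step; attention is forced
        when the cudagraph mode is FULL_DECODE_ONLY."""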
        force_attention = self.compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY
        self.model_runner._dummy_run(num_tokens=1,
                                     uniform_decode=True,
                                     force_attention=force_attention)

    def _init_worker_distributed_environment(self) -> None:
        """Initialize the distributed environment."""
        init_distributed_environment(self.parallel_config.world_size,
                                     self.rank, self.distributed_init_method,
                                     self.local_rank, "hccl")
        ensure_model_parallel_initialized(
            self.parallel_config.tensor_parallel_size,
            self.parallel_config.pipeline_parallel_size)
        init_ascend_model_parallel(self.parallel_config)
        ensure_kv_transfer_initialized(self.vllm_config)

    def _init_profiler(self):
        # Torch profiler. Enabled and configured through env vars:
        # VLLM_TORCH_PROFILER_DIR=/path/to/save/trace
        if envs_vllm.VLLM_TORCH_PROFILER_DIR:
            if envs_ascend.MSMONITOR_USE_DAEMON:
                raise RuntimeError(
                    "MSMONITOR_USE_DAEMON and VLLM_TORCH_PROFILER_DIR cannot "
                    "both be set at the same time.")
            torch_profiler_trace_dir = envs_vllm.VLLM_TORCH_PROFILER_DIR
            logger.info("Profiling enabled. Traces will be saved to: %s",
                        torch_profiler_trace_dir)

            experimental_config = torch_npu.profiler._ExperimentalConfig(
                export_type=torch_npu.profiler.ExportType.Text,
                profiler_level=torch_npu.profiler.ProfilerLevel.Level1,
                msprof_tx=False,
                aic_metrics=torch_npu.profiler.AiCMetrics.AiCoreNone,
                l2_cache=False,
                op_attr=False,
                data_simplification=False,
                record_op_args=False,
                gc_detect_threshold=None,
            )

            return torch_npu.profiler.profile(
                activities=[
                    torch_npu.profiler.ProfilerActivity.CPU,
                    torch_npu.profiler.ProfilerActivity.NPU,
                ],
                with_stack=envs_vllm.VLLM_TORCH_PROFILER_WITH_STACK,
                profile_memory=envs_vllm.\
                    VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY,
                with_modules=False,
                experimental_config=experimental_config,
                on_trace_ready=torch_npu.profiler.tensorboard_trace_handler(
                    torch_profiler_trace_dir))
        else:
            return None

    def get_supported_pooling_tasks(self):
        return self.model_runner.get_supported_pooling_tasks()

    def get_supported_tasks(self) -> "tuple[SupportedTask, ...]":
        return self.model_runner.get_supported_tasks()

    def take_draft_token_ids(self) -> Optional[DraftTokenIds]:
        return self.model_runner.take_draft_token_ids()