mirror of https://github.com/vllm-project/vllm-ascend.git, synced 2025-10-20 13:43:53 +08:00
### What this PR does / why we need it?
This PR adds a sleep mode feature to vllm-ascend. When the engine sleeps, we mainly do two things:
- offload model weights
- discard KV cache

RLHF tools (such as https://github.com/volcengine/verl and https://github.com/OpenRLHF/OpenRLHF) have a strong need for sleep mode to accelerate the training process. This PR may solve #375 and #320.

### Does this PR introduce _any_ user-facing change?
No existing user interfaces changed. Users get two new methods, `sleep()` and `wake_up()`.

### How was this patch tested?
This PR is tested with Qwen/Qwen2.5-0.5B-Instruct. At first, we have free NPU memory M1. After `llm = LLM("Qwen/Qwen2.5-0.5B-Instruct", enable_sleep_mode=True)` is executed, we have free NPU memory M2, with M2 < M1. Then we call `llm.sleep(level=1)` and observe free NPU memory M3, with M3 > M2 and M3 very close to M1. In addition, with the same input tokens and `SamplingParams(temperature=0, max_tokens=10)`, the output tokens before sleep and after wake-up are identical.

This PR builds on the CMake procedure of #371, thanks a lot.

Signed-off-by: Shuqiao Li <celestialli@outlook.com>
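For reference, a minimal sketch of the flow described in the test section above, run offline on an Ascend NPU host. The prompt string is illustrative, and measuring free memory via `torch.npu.mem_get_info()` (corresponding to M1/M2/M3) is an assumption of this sketch rather than part of the PR:

```python
import torch
import torch_npu  # noqa: F401  # makes the torch.npu namespace available

from vllm import LLM, SamplingParams

free_m1, _ = torch.npu.mem_get_info()          # M1: free NPU memory before loading

llm = LLM("Qwen/Qwen2.5-0.5B-Instruct", enable_sleep_mode=True)
free_m2, _ = torch.npu.mem_get_info()          # M2 < M1: weights and KV cache resident

params = SamplingParams(temperature=0, max_tokens=10)
before = llm.generate(["San Francisco is a"], params)

llm.sleep(level=1)                             # offload model weights, discard KV cache
free_m3, _ = torch.npu.mem_get_info()          # M3 > M2, close to M1

llm.wake_up()
after = llm.generate(["San Francisco is a"], params)

# With greedy sampling, the outputs should match before sleep and after wake-up.
assert before[0].outputs[0].token_ids == after[0].outputs[0].token_ids
```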
194 lines
7.0 KiB
Python
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#

import logging
import os
from typing import TYPE_CHECKING, Optional, Tuple

import torch
import torch_npu  # noqa: F401
import vllm.envs as envs
from vllm.logger import logger
from vllm.platforms import Platform, PlatformEnum

CUSTOM_OP_ENABLED = False
try:
    # register custom ops into torch_library here
    import vllm_ascend.vllm_ascend_C  # type: ignore  # noqa: F401

except ImportError as e:
    if str(e) != ("dynamic module does not define module export function "
                  "(PyInit_vllm_ascend_C)"):
        logging.warning(
            "Warning: Failed to register custom ops, all custom ops will be disabled"
        )
else:
    CUSTOM_OP_ENABLED = True

if TYPE_CHECKING:
    from vllm.config import ModelConfig, VllmConfig
    from vllm.utils import FlexibleArgumentParser
else:
    ModelConfig = None
    VllmConfig = None
    FlexibleArgumentParser = None

os.environ["RAY_EXPERIMENTAL_NOSET_ASCEND_RT_VISIBLE_DEVICES"] = "1"


class NPUPlatform(Platform):

    _enum = PlatformEnum.OOT
    device_name: str = "npu"
    device_type: str = "npu"
    simple_compile_backend: str = "eager"  # Disable torch.compile()
    ray_device_key: str = "NPU"
    device_control_env_var: str = "ASCEND_RT_VISIBLE_DEVICES"
    dispatch_key: str = "PrivateUse1"

    supported_quantization: list[str] = ["ascend"]

    def is_sleep_mode_available(self) -> bool:
        return True

    @classmethod
    def pre_register_and_update(cls,
                                parser: Optional[FlexibleArgumentParser] = None
                                ) -> None:
        # Adapt the global patch here.
        from vllm_ascend.utils import adapt_patch
        adapt_patch(is_global_patch=True)

        from vllm_ascend.quantization.quant_config import \
            AscendQuantConfig  # noqa: F401

    @classmethod
    def get_device_capability(cls, device_id: int = 0):
        return None

    @classmethod
    def get_device_name(cls, device_id: int = 0) -> str:
        return torch.npu.get_device_name(device_id)

    @classmethod
    def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool:
        return True

    @classmethod
    def inference_mode(cls):
        return torch.inference_mode()

    @classmethod
    def set_device(cls, device: torch.device):
        torch.npu.set_device(device)

    @classmethod
    def empty_cache(cls):
        torch.npu.empty_cache()

    @classmethod
    def synchronize(cls):
        torch.npu.synchronize()

    @classmethod
    def mem_get_info(cls) -> Tuple[int, int]:
        return torch.npu.mem_get_info()

    @classmethod
    def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
        from vllm.config import CompilationLevel  # noqa: E402
        compilation_config = vllm_config.compilation_config
        if compilation_config and compilation_config.level != CompilationLevel.NO_COMPILATION:
            logger.warning(
                "Compilation level %s is not supported on NPU now, forcing compilation level to NO_COMPILATION",
                compilation_config.level)
            compilation_config.level = CompilationLevel.NO_COMPILATION

        parallel_config = vllm_config.parallel_config
        if parallel_config and parallel_config.worker_cls == "auto":
            if envs.VLLM_USE_V1:
                parallel_config.worker_cls = "vllm_ascend.worker.worker_v1.NPUWorker"
            elif vllm_config.speculative_config:
                parallel_config.worker_cls = "vllm.spec_decode.spec_decode_worker.create_spec_worker"
                parallel_config.sd_worker_cls = "vllm_ascend.worker.worker.NPUWorker"
            elif vllm_config.scheduler_config.is_multi_step:
                parallel_config.worker_cls = "vllm_ascend.worker.multi_step_worker.MultiStepWorker"
            else:
                parallel_config.worker_cls = "vllm_ascend.worker.worker.NPUWorker"

        cache_config = vllm_config.cache_config
        if cache_config:
            if cache_config.block_size is None:
                cache_config.block_size = 128
            if envs.VLLM_USE_V1 and cache_config.enable_prefix_caching:
                logger.warning(
                    "Prefix caching is not supported for V1 now, disable prefix caching"
                )
                cache_config.enable_prefix_caching = False

        if envs.VLLM_USE_V1:
            # Activate custom ops for v1.
            vllm_config.compilation_config.custom_ops = ["all"]
            additional_config = vllm_config.additional_config
            # If ascend_scheduler_config exists in additional_config,
            # extend the original scheduler_config to use AscendScheduler.
            if additional_config and additional_config.get(
                    "ascend_scheduler_config", None) is not None:
                additional_scheduler_config = additional_config.get(
                    "ascend_scheduler_config")
                from vllm_ascend.core.schedule_config import \
                    AscendSchedulerConfig
                ascend_scheduler_config = AscendSchedulerConfig.initialize_from_config(
                    vllm_config.scheduler_config, additional_scheduler_config)
                vllm_config.scheduler_config = ascend_scheduler_config

    @classmethod
    def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
                             kv_cache_dtype, block_size, use_v1, use_mla):
        if use_v1:
            return "vllm_ascend.attention.attention_v1.AscendAttentionBackend"
        if use_mla:
            return "vllm_ascend.attention.attention.AscendMLAAttentionBackend"
        return "vllm_ascend.attention.attention.AscendAttentionBackend"

    @classmethod
    def get_punica_wrapper(cls) -> str:
        return "vllm_ascend.lora.punica_wrapper.punica_npu.PunicaWrapperNPU"

    @classmethod
    def get_current_memory_usage(cls,
                                 device: Optional[torch.types.Device] = None
                                 ) -> float:
        torch.npu.reset_peak_memory_stats(device)
        return torch.npu.max_memory_allocated(device)

    @classmethod
    def get_device_communicator_cls(cls) -> str:
        return "vllm_ascend.distributed.communicator.NPUCommunicator"

    @classmethod
    def is_pin_memory_available(cls):
        return True

    @classmethod
    def supports_v1(cls, model_config: ModelConfig) -> bool:
        """Returns whether the current platform can support v1 for the supplied
        model configuration.
        """
        return True
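For context on how a platform class like this gets picked up: `NPUPlatform` is an out-of-tree platform (`_enum = PlatformEnum.OOT`), which vLLM discovers through its plugin system. Below is a minimal sketch, assuming the standard `vllm.platform_plugins` entry-point mechanism; the exact packaging in vllm-ascend may differ.

```python
# Sketch of out-of-tree platform registration (assumed packaging, for illustration).
#
# In the plugin package's setup.py / pyproject, an entry point in the
# "vllm.platform_plugins" group points at a register() function, e.g.:
#
#   entry_points={
#       "vllm.platform_plugins": ["ascend = vllm_ascend:register"],
#   }


def register() -> str:
    """Called by vLLM at startup; returns the import path of the Platform class."""
    return "vllm_ascend.platform.NPUPlatform"
```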