Add sleep mode feature for Ascend NPU (#513)

### What this PR does / why we need it?
This PR adds sleep mode feature for vllm-ascend, when sleeps, we do
mainly two things:

- offload model weights
- discard kv cache

RLHF tools(such as https://github.com/volcengine/verl and
https://github.com/OpenRLHF/OpenRLHF) have a strong need of sleep mode
to accelerate the training process.

This PR may solve #375 and #320 .

### Does this PR introduce _any_ user-facing change?
No existing user interfaces changed.
Users will have two new methods(`sleep()` and `wake_up()`) to use.

### How was this patch tested?
This PR is tested with Qwen/Qwen2.5-0.5B-Instruct.

At first, we have free NPU memory M1.

After `llm = LLM("Qwen/Qwen2.5-0.5B-Instruct", enable_sleep_mode=True)`
executed, we have free NPU memory M2. M2 < M1.

Then we call `llm.sleep(level=1)`, we have free NPU memory M3.

We have M3 > M2, M3 is very close to M1.

Plus, we have the same output tokens before sleep and after wake up,
with the config of `SamplingParams(temperature=0, max_tokens=10)` and
with the same input tokens of course.


This PR is utilizing the CMake procedure of #371 , thanks a lot.

Signed-off-by: Shuqiao Li <celestialli@outlook.com>
This commit is contained in:
Shuqiao Li
2025-04-18 13:11:39 +08:00
committed by GitHub
parent 42c7fbb10e
commit 84563fc65d
13 changed files with 1020 additions and 9 deletions

View File

@ -3,7 +3,7 @@ project(vllm_ascend_C)
# include(CheckCXXcompilerFlag)
# check_cxx_compiler_flag("-std=c++17", COMPILER_SUPPORTS_CXX17)
set(CMAKE_CXX_STANDARD 17)
include(${CMAKE_CURRENT_LIST_DIR}/cmake/utils.cmake)

338
csrc/camem_allocator.cpp Normal file
View File

@ -0,0 +1,338 @@
/*
* Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <iostream>
extern "C" {
#define PY_SSIZE_T_CLEAN
#include <Python.h>
#include <sys/types.h>
#include "acl/acl.h"
// Global references to Python callables
// NOTE: this is borrowed reference, so we don't need to DECREF them.
// This brings the limitation that the allocator needs to be singleton.
static PyObject* g_python_malloc_callback = nullptr;
static PyObject* g_python_free_callback = nullptr;
// ---------------------------------------------------------------------------
// Helper functions:
void ensure_context(unsigned long long device) {
aclrtContext pctx;
aclrtGetCurrentContext(&pctx);
if (!pctx) {
// Ensure device context.
aclrtCreateContext(&pctx, device);
aclrtSetCurrentContext(pctx);
}
}
void create_and_map(unsigned long long device, ssize_t size, void* d_mem,
aclrtDrvMemHandle* p_memHandle) {
ensure_context(device);
// Define memory allocation properties
aclrtPhysicalMemProp prop = {};
prop.handleType = ACL_MEM_HANDLE_TYPE_NONE ;
prop.allocationType = ACL_MEM_ALLOCATION_TYPE_PINNED;
prop.memAttr = ACL_HBM_MEM_HUGE;
prop.location.id = device;
prop.location.type = ACL_MEM_LOCATION_TYPE_DEVICE;
prop.reserve = 0;
// Allocate memory using aclrtMallocPhysical
aclError error_code = aclrtMallocPhysical(p_memHandle, size, &prop, 0);
if (error_code != 0) {
std::cerr << "acl Error, code: " << error_code << " at " << __FILE__ << ":" \
<< __LINE__ << std::endl;
return;
}
error_code = aclrtMapMem(d_mem, size, 0, *p_memHandle, 0);
if (error_code != 0) {
std::cerr << "acl Error, code: " << error_code << " at " << __FILE__ << ":" \
<< __LINE__ << std::endl;
return;
}
}
void unmap_and_release(unsigned long long device, ssize_t size,
void* d_mem,
aclrtDrvMemHandle* p_memHandle) {
// std::cout << "unmap_and_release: device=" << device << ", size=" << size <<
// ", d_mem=" << d_mem << ", p_memHandle=" << p_memHandle << std::endl;
ensure_context(device);
aclError error_code = aclrtUnmapMem(d_mem);
if (error_code != 0) {
std::cerr << "acl Error, code: " << error_code << " at " << __FILE__ << ":" \
<< __LINE__ << std::endl;
return;
}
error_code = aclrtFreePhysical(*p_memHandle);
if (error_code != 0) {
std::cerr << "acl Error, code: " << error_code << " at " << __FILE__ << ":" \
<< __LINE__ << std::endl;
return;
}
}
PyObject* create_tuple_from_c_integers(unsigned long long a,
unsigned long long b,
unsigned long long c,
unsigned long long d) {
// Create a new tuple of size 4
PyObject* tuple = PyTuple_New(4);
if (!tuple) {
return NULL; // Return NULL on failure
}
// Convert integers to Python objects and set them in the tuple
PyTuple_SetItem(
tuple, 0,
PyLong_FromUnsignedLongLong(a)); // Steals reference to the PyLong
PyTuple_SetItem(tuple, 1, PyLong_FromUnsignedLongLong(b));
PyTuple_SetItem(tuple, 2, PyLong_FromUnsignedLongLong(c));
PyTuple_SetItem(tuple, 3, PyLong_FromUnsignedLongLong(d));
// Note: PyTuple_SetItem "steals" a reference to each object,
// so we do not need to Py_DECREF the PyLong objects explicitly.
return tuple; // Return the created tuple
}
// ---------------------------------------------------------------------------
// Our exported C functions that call Python:
__attribute__ ((visibility("default"))) void* my_malloc(ssize_t size, int device, aclrtStream stream) {
ensure_context(device);
// first allocation, align the size, and reserve an address, and also allocate
// a aclrtDrvMemHandle
// Define memory allocation properties
aclrtPhysicalMemProp prop = {};
prop.handleType = ACL_MEM_HANDLE_TYPE_NONE ;
prop.allocationType = ACL_MEM_ALLOCATION_TYPE_PINNED;
prop.memAttr = ACL_HBM_MEM_HUGE;
prop.location.id = device;
prop.location.type = ACL_MEM_LOCATION_TYPE_DEVICE;
prop.reserve = 0;
// Check if the allocation is supported
size_t granularity;
aclError error_code = aclrtMemGetAllocationGranularity(&prop,
ACL_RT_MEM_ALLOC_GRANULARITY_MINIMUM,
&granularity);
if (error_code != 0) {
std::cerr << "acl Error, code: " << error_code << " at " << __FILE__ << ":" \
<< __LINE__ << std::endl;
return nullptr;
}
size_t alignedSize = ((size + granularity - 1) / granularity) * granularity;
void *d_mem;
error_code = aclrtReserveMemAddress(&d_mem, alignedSize, 0, nullptr, 0);
if (error_code != 0) {
std::cerr << "acl Error, code: " << error_code << " at " << __FILE__ << ":" \
<< __LINE__ << std::endl;
return nullptr;
}
// allocate the aclrtDrvMemHandle
aclrtDrvMemHandle* p_memHandle =
(aclrtDrvMemHandle*)malloc(sizeof(aclrtDrvMemHandle));
if (!g_python_malloc_callback) {
std::cerr << "ERROR: g_python_malloc_callback not set.\n";
return nullptr;
}
// Acquire GIL (not in stable ABI officially, but often works)
PyGILState_STATE gstate = PyGILState_Ensure();
PyObject* arg_tuple = create_tuple_from_c_integers(
(unsigned long long)device, (unsigned long long)alignedSize,
(unsigned long long)d_mem, (unsigned long long)p_memHandle);
// Call g_python_malloc_callback
PyObject* py_result =
PyObject_CallFunctionObjArgs(g_python_malloc_callback, arg_tuple, NULL);
Py_DECREF(arg_tuple);
if (!py_result) {
PyErr_Print();
PyGILState_Release(gstate);
return nullptr;
}
PyGILState_Release(gstate);
// do the final mapping
create_and_map(device, alignedSize, d_mem, p_memHandle);
return (void*)d_mem;
}
__attribute__ ((visibility("default"))) void my_free(void* ptr, ssize_t size, int device, aclrtStream stream) {
// get memory handle from the pointer
if (!g_python_free_callback) {
std::cerr << "ERROR: g_python_free_callback not set.\n";
return;
}
// Acquire GIL (not in stable ABI officially, but often works)
PyGILState_STATE gstate = PyGILState_Ensure();
PyObject* py_ptr =
PyLong_FromUnsignedLongLong(reinterpret_cast<unsigned long long>(ptr));
PyObject* py_result =
PyObject_CallFunctionObjArgs(g_python_free_callback, py_ptr, NULL);
if (!py_result || !PyTuple_Check(py_result) || PyTuple_Size(py_result) != 4) {
PyErr_SetString(PyExc_TypeError, "Expected a tuple of size 4");
return;
}
unsigned long long recv_device, recv_size;
unsigned long long recv_d_mem, recv_p_memHandle;
// Unpack the tuple into four C integers
if (!PyArg_ParseTuple(py_result, "KKKK", &recv_device, &recv_size,
&recv_d_mem, &recv_p_memHandle)) {
// PyArg_ParseTuple sets an error if it fails
return;
}
PyGILState_Release(gstate);
// recv_size == size
// recv_device == device
// Free memory
void *d_mem = (void*)recv_d_mem;
// allocate the aclrtDrvMemHandle
aclrtDrvMemHandle* p_memHandle =
(aclrtDrvMemHandle*)recv_p_memHandle;
unmap_and_release(device, size, d_mem, p_memHandle);
// free address and the handle
aclError error_code = aclrtReleaseMemAddress(d_mem);
if (error_code != 0) {
std::cerr << "acl Error, code: " << error_code << " at " << __FILE__ << ":" \
<< __LINE__ << std::endl;
return;
}
free(p_memHandle);
}
// ---------------------------------------------------------------------------
// Python extension boilerplate:
// Python-exposed function: init_module(python_malloc, python_free)
static PyObject* py_init_module(PyObject* self, PyObject* args) {
PyObject* malloc_callback = nullptr;
PyObject* free_callback = nullptr;
if (!PyArg_ParseTuple(args, "OO", &malloc_callback, &free_callback)) {
return nullptr;
}
if (!PyCallable_Check(malloc_callback) || !PyCallable_Check(free_callback)) {
PyErr_SetString(PyExc_TypeError, "Both arguments must be callables");
return nullptr;
}
// Save the Python callables
// This module does not handle GC of these objects, so they must be kept alive
// outside of this module.
g_python_malloc_callback = malloc_callback;
g_python_free_callback = free_callback;
Py_RETURN_NONE;
}
static PyObject* python_unmap_and_release(PyObject* self, PyObject* args) {
if (!args || !PyTuple_Check(args) || PyTuple_Size(args) != 4) {
PyErr_SetString(PyExc_TypeError, "Expected a tuple of size 4");
return nullptr;
}
unsigned long long recv_device, recv_size;
unsigned long long recv_d_mem, recv_p_memHandle;
// Unpack the tuple into four C integers
if (!PyArg_ParseTuple(args, "KKKK", &recv_device, &recv_size, &recv_d_mem,
&recv_p_memHandle)) {
// PyArg_ParseTuple sets an error if it fails
return nullptr;
}
void *d_mem_ptr = (void*)recv_d_mem;
aclrtDrvMemHandle* p_memHandle =
(aclrtDrvMemHandle*)recv_p_memHandle;
unmap_and_release(recv_device, recv_size, d_mem_ptr, p_memHandle);
Py_RETURN_NONE;
}
static PyObject* python_create_and_map(PyObject* self, PyObject* args) {
if (!args || !PyTuple_Check(args) || PyTuple_Size(args) != 4) {
PyErr_SetString(PyExc_TypeError, "Expected a tuple of size 4");
return nullptr;
}
unsigned long long recv_device, recv_size;
unsigned long long recv_d_mem, recv_p_memHandle;
// Unpack the tuple into four C integers
if (!PyArg_ParseTuple(args, "KKKK", &recv_device, &recv_size, &recv_d_mem,
&recv_p_memHandle)) {
// PyArg_ParseTuple sets an error if it fails
return nullptr;
}
void *d_mem_ptr = (void*)recv_d_mem;
aclrtDrvMemHandle* p_memHandle =
(aclrtDrvMemHandle*)recv_p_memHandle;
create_and_map(recv_device, recv_size, d_mem_ptr, p_memHandle);
Py_RETURN_NONE;
}
static PyMethodDef module_methods[] = {
{"init_module", (PyCFunction)py_init_module, METH_VARARGS,
"Initialize module with python_malloc and python_free callables."},
{"python_create_and_map", (PyCFunction)python_create_and_map, METH_VARARGS,
"Create and map memory on the device."},
{"python_unmap_and_release", (PyCFunction)python_unmap_and_release,
METH_VARARGS, "Unmap and release memory on the device."},
{NULL, NULL, 0, NULL} // sentinel
};
static struct PyModuleDef camem_allocator_module = {
PyModuleDef_HEAD_INIT, "camem_allocator",
"CANN-mem-based allocator for NPUPluggableAllocator", -1, module_methods};
PyMODINIT_FUNC PyInit_vllm_ascend_C(void) {
// Initialize the module
PyObject* module = PyModule_Create(&camem_allocator_module);
if (!module) {
return NULL;
}
return module;
}
} // extern "C"

View File

@ -123,6 +123,10 @@ class cmake_build_ext(build_ext):
cmake_args += [f"-DCMAKE_BUILD_TYPE={envs.CMAKE_BUILD_TYPE}"]
# Default dump the compile commands for lsp
cmake_args += ["-DCMAKE_EXPORT_COMPILE_COMMANDS=1"]
if envs.CXX_COMPILER is not None:
cmake_args += [f"-DCMAKE_CXX_COMPILER={envs.CXX_COMPILER}"]
if envs.C_COMPILER is not None:
cmake_args += [f"-DCMAKE_C_COMPILER={envs.C_COMPILER}"]
if envs.VERBOSE:
cmake_args += ["-DCMAKE_VERBOSE_MAKEFILE=ON"]

View File

@ -0,0 +1,92 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
# Adapted from vllm/tests/basic_correctness/test_basic_correctness.py
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
import torch
from vllm import LLM, SamplingParams
from vllm.utils import GiB_bytes
from tests.utils import fork_new_process_for_each_test
from vllm_ascend.device_allocator.camem import CaMemAllocator
try:
import torch_npu # noqa: F401
except ImportError:
print("Failed to import torch_npu.")
@fork_new_process_for_each_test
def test_basic_camem():
# some tensors from default memory pool
shape = (1024, 1024)
x = torch.empty(shape, device='npu:0')
x.zero_()
# some tensors from custom memory pool
allocator = CaMemAllocator.get_instance()
with allocator.use_memory_pool():
# custom memory pool
y = torch.empty(shape, device='npu:0')
y.zero_()
y += 1
z = torch.empty(shape, device='npu:0')
z.zero_()
z += 2
# they can be used together
output = x + y + z
assert torch.allclose(output, torch.ones_like(output) * 3)
free_bytes = torch_npu.npu.mem_get_info()[0]
allocator.sleep()
free_bytes_after_sleep = torch_npu.npu.mem_get_info()[0]
assert free_bytes_after_sleep > free_bytes
allocator.wake_up()
# they can be used together
output = x + y + z
assert torch.allclose(output, torch.ones_like(output) * 3)
@fork_new_process_for_each_test
def test_end_to_end():
os.environ["VLLM_USE_V1"] = "0"
free, total = torch_npu.npu.mem_get_info()
used_bytes_baseline = total - free # in case other process is running
llm = LLM("Qwen/Qwen2.5-0.5B-Instruct", enable_sleep_mode=True)
prompt = "How are you?"
sampling_params = SamplingParams(temperature=0, max_tokens=10)
output = llm.generate(prompt, sampling_params)
# the benefit of `llm.sleep(level=2)` is mainly CPU memory usage,
# which is difficult to measure in the test. therefore, we only
# test sleep level 1 here.
llm.sleep(level=1)
free_gpu_bytes_after_sleep, total = torch_npu.npu.mem_get_info()
used_bytes = total - free_gpu_bytes_after_sleep - used_bytes_baseline
# now the memory usage should be less than the model weights
# (0.5B model, 1GiB weights)
assert used_bytes < 1 * GiB_bytes
llm.wake_up()
output2 = llm.generate(prompt, sampling_params)
# cmp output
assert output[0].outputs[0].text == output2[0].outputs[0].text

View File

@ -18,9 +18,6 @@
def register():
"""Register the NPU platform."""
# Adapt the global patch here.
from vllm_ascend.utils import adapt_patch
adapt_patch(is_global_patch=True)
return "vllm_ascend.platform.NPUPlatform"

View File

View File

@ -0,0 +1,283 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# CANN-mem-based pytorch pluggable allocator to implement sleep mode.
#
import dataclasses
import os
from contextlib import contextmanager
from typing import Any, Callable, Dict, Optional, Tuple, Union
import torch
from acl.rt import memcpy # type: ignore # noqa: F401
from vllm.logger import logger
try:
import torch_npu # noqa: F401
except ImportError:
print("Failed to import torch_npu.")
from vllm.utils import is_pin_memory_available
def find_loaded_library(lib_name) -> Optional[str]:
"""
According to according to https://man7.org/linux/man-pages/man5/proc_pid_maps.5.html,
the file `/proc/self/maps` contains the memory maps of the process, which includes the
shared libraries loaded by the process. We can use this file to find the path of the
a loaded library.
""" # noqa
found_line = None
with open("/proc/self/maps") as f:
for line in f:
if lib_name in line:
found_line = line
break
if found_line is None:
# the library is not loaded in the current process
return None
# if lib_name is libcudart, we need to match a line with:
# address /path/to/libcudart-hash.so.11.0
start = found_line.index("/")
path = found_line[start:].strip()
filename = path.split("/")[-1]
assert filename.rpartition(".so")[0].startswith(lib_name), \
f"Unexpected filename: {filename} for library {lib_name}"
return path
camem_available = False
try:
from vllm_ascend.vllm_ascend_C import ( # type: ignore # noqa: F401
init_module, python_create_and_map, python_unmap_and_release)
lib_name = find_loaded_library("vllm_ascend_C")
camem_available = True
except ModuleNotFoundError as e:
logger.error("Failed to import vllm_ascend_C:%s", e)
init_module = None
python_create_and_map = None
python_unmap_and_release = None
lib_name = None
libcudart = None
# py_device, py_alignedSize, py_d_mem, py_p_memHandle
HandleType = Tuple[int, int, int, int]
@dataclasses.dataclass
class AllocationData:
handle: HandleType
tag: str
cpu_backup_tensor: Optional[torch.Tensor] = None
def create_and_map(allocation_handle: HandleType) -> None:
python_create_and_map(*allocation_handle)
def unmap_and_release(allocation_handle: HandleType) -> None:
python_unmap_and_release(*allocation_handle)
def get_pluggable_allocator(
python_malloc_fn: Callable[[tuple[int, int, int, int]], None],
python_free_func: Callable[[int], tuple[int, int, int, int]]
) -> torch_npu.npu.memory.NPUPluggableAllocator:
init_module(python_malloc_fn, python_free_func)
new_alloc = torch_npu.npu.memory.NPUPluggableAllocator(
lib_name, 'my_malloc', 'my_free')
return new_alloc
@contextmanager
def use_memory_pool_with_allocator(
python_malloc_fn: Callable[[tuple[int, int, int, int]], None],
python_free_func: Callable[[int], tuple[int, int, int, int]]):
new_alloc = get_pluggable_allocator(python_malloc_fn, python_free_func)
mem_pool = torch_npu.npu.memory.MemPool(new_alloc._allocator)
with torch_npu.npu.memory.use_mem_pool(mem_pool):
yield mem_pool, new_alloc
class CaMemAllocator:
"""
A singleton class that manages a memory pool for CANN tensors.
The memory in this pool can be offloaded or discarded when the
allocator sleeps.
Inside the `use_memory_pool(tag)` context, all tensors created will
be allocated in the memory pool, and has the same tag as the
tag passed to the context.
When we call `sleep`, all tensors with the specified tag will be
offloaded to CPU memory, and the rest of the tensors will be discarded.
When we call `wake_up`, all tensors that are previously offloaded
will be loaded back to GPU memory, and the rest of the tensors will
have empty memory.
Why it needs to be a singleton?
When allocated tensors are garbage collected, PyTorch will call
the free callback, which will call the `python_free_callback` method.
The C-extension uses a global variable to store the function of an
instance of this class. If we create multiple instances of this class,
the global variable will be overwritten and the free callback will
not work as expected.
"""
instance = None
default_tag: str = "default"
@staticmethod
def get_instance() -> "CaMemAllocator":
"""
CaMemAllocator is a singleton class.
We cannot call the constructor directly.
Call this method to get the instance.
"""
assert camem_available, "camem allocator is not available"
if CaMemAllocator.instance is None:
CaMemAllocator.instance = CaMemAllocator()
return CaMemAllocator.instance
def __init__(self):
conf = os.environ.get("PYTORCH_CUDA_ALLOC_CONF", "")
assert "expandable_segments:True" not in conf, \
("Expandable segments are not compatible with memory pool. "
"Please track https://github.com/pytorch/pytorch/issues/147851 "
"for the latest updates.")
self.pointer_to_data: Dict[int, AllocationData] = {}
self.current_tag: str = CaMemAllocator.default_tag
self.allocator_and_pools: Dict[str, Any] = {}
def python_malloc_callback(self, allocation_handle: HandleType) -> None:
"""
Internal method to store the allocation data
when memory is allocated in the memory pool."""
py_d_mem = allocation_handle[2]
self.pointer_to_data[py_d_mem] = AllocationData(
allocation_handle, self.current_tag)
return
def python_free_callback(self, ptr: int) -> HandleType:
"""
Internal method to look up the allocation data
when memory is freed in the memory pool."""
data = self.pointer_to_data.pop(ptr)
if data.cpu_backup_tensor is not None:
data.cpu_backup_tensor = None
return data.handle
def sleep(
self,
offload_tags: Optional[Union[Tuple[str, ...],
str]] = None) -> None:
"""
Put the allocator in sleep mode.
All data in the memory allocation with the specified tag will be
offloaded to CPU memory, and others will be discarded.
:param offload_tags: The tags of the memory allocation that will be
offloaded. The rest of the memory allocation will be discarded.
"""
if offload_tags is None:
# by default, allocated tensors are offloaded
# when the allocator sleeps
offload_tags = (CaMemAllocator.default_tag, )
elif isinstance(offload_tags, str):
offload_tags = (offload_tags, )
assert isinstance(offload_tags, tuple)
for ptr, data in self.pointer_to_data.items():
handle = data.handle
if data.tag in offload_tags:
size_in_bytes = handle[1]
cpu_backup_tensor = torch.empty(
size_in_bytes,
dtype=torch.uint8,
device='cpu',
pin_memory=is_pin_memory_available())
cpu_ptr = cpu_backup_tensor.data_ptr()
ACL_MEMCPY_DEVICE_TO_HOST = 2
dest_max = cpu_ptr + size_in_bytes * 2
memcpy(cpu_ptr, dest_max, ptr, size_in_bytes,
ACL_MEMCPY_DEVICE_TO_HOST)
data.cpu_backup_tensor = cpu_backup_tensor
unmap_and_release(handle)
def wake_up(self, tags: Optional[list[str]] = None) -> None:
"""
Wake up the allocator from sleep mode.
All data that is previously offloaded will be loaded back to GPU
memory, and the rest of the data will have empty memory."""
for ptr, data in self.pointer_to_data.items():
if tags is None or data.tag in tags:
handle = data.handle
create_and_map(handle)
if data.cpu_backup_tensor is not None:
cpu_backup_tensor = data.cpu_backup_tensor
if cpu_backup_tensor is not None:
size_in_bytes = cpu_backup_tensor.numel(
) * cpu_backup_tensor.element_size()
cpu_ptr = cpu_backup_tensor.data_ptr()
ACL_MEMCPY_HOST_TO_DEVICE = 1
dest_max = ptr + size_in_bytes * 2
memcpy(ptr, dest_max, cpu_ptr, size_in_bytes,
ACL_MEMCPY_HOST_TO_DEVICE)
data.cpu_backup_tensor = None
@contextmanager
def use_memory_pool(self, tag: Optional[str] = None):
"""
A context manager to use the memory pool.
All memory allocation created inside the context will be allocated
in the memory pool, and has the specified tag.
:param tag: The tag of the memory allocation. If None, the default tag
will be used.
"""
if tag is None:
tag = CaMemAllocator.default_tag
assert isinstance(tag, str)
old_tag = self.current_tag
self.current_tag = tag
with use_memory_pool_with_allocator(self.python_malloc_callback,
self.python_free_callback) as data:
# start to hit another PyTorch bug in PyTorch 2.6,
# possibly because of gc-related issue w.r.t. the allocator and
# the memory pool.
# to avoid the issue, we keep a reference of the data.
# see https://github.com/pytorch/pytorch/issues/146431 .
self.allocator_and_pools[tag] = data
yield
# PyTorch's bug, calling torch.cuda.empty_cache() will error
# when using pluggable allocator, see
# https://github.com/pytorch/pytorch/issues/145168 .
# if we have some memory allocated and then freed,
# the memory will not be released.
# right now it is fine, because we only use this allocator
# during weight loading and kv cache creation, where we only
# allocate memory.
# TODO: we need to find a way to release the memory,
# i.e. calling torch.cuda.empty_cache()
self.current_tag = old_tag
def get_current_usage(self) -> int:
"""
Get the total number of bytes allocated in the memory pool.
"""
sum_bytes: int = 0
for ptr, data in self.pointer_to_data.items():
handle = data.handle
sum_bytes += handle[1]
return sum_bytes

View File

@ -56,6 +56,10 @@ env_variables: Dict[str, Callable[[], Any]] = {
lambda: os.getenv("LLMDATADIST_COMM_PORT", "26000"),
"LLMDATADIST_SYNC_CACHE_WAIT_TIME":
lambda: os.getenv("LLMDATADIST_SYNC_CACHE_WAIT_TIME", "5000"),
"CXX_COMPILER":
lambda: os.getenv("CXX_COMPILER", None),
"C_COMPILER":
lambda: os.getenv("C_COMPILER", None)
}
# end-env-vars-definition

View File

@ -13,4 +13,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
#
import vllm_ascend.patch.platform.patch_0_8_4.patch_config # noqa

View File

@ -0,0 +1,243 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import json
import warnings
from importlib.util import find_spec
from typing import Any, Final, Literal, Mapping, Optional, Union
import torch
import vllm.envs as envs
from vllm.config import (HfOverrides, ModelConfig, ModelImpl, PoolerConfig,
TaskOption, _get_and_verify_dtype,
_get_and_verify_max_len, get_min_sliding_window,
get_served_model_name, logger)
from vllm.transformers_utils.config import (ConfigFormat, get_config,
get_hf_image_processor_config,
get_hf_text_config)
from vllm.transformers_utils.utils import maybe_model_redirect
def new_init(
self,
model: str,
task: Union[TaskOption, Literal["draft"]],
tokenizer: str,
tokenizer_mode: str,
trust_remote_code: bool,
dtype: Union[str, torch.dtype],
seed: int,
hf_config_path: Optional[str] = None,
allowed_local_media_path: str = "",
revision: Optional[str] = None,
code_revision: Optional[str] = None,
rope_scaling: Optional[dict[str, Any]] = None,
rope_theta: Optional[float] = None,
tokenizer_revision: Optional[str] = None,
max_model_len: Optional[int] = None,
spec_target_max_model_len: Optional[int] = None,
quantization: Optional[str] = None,
enforce_eager: Optional[bool] = None,
max_seq_len_to_capture: Optional[int] = None,
max_logprobs: int = 20,
disable_sliding_window: bool = False,
disable_cascade_attn: bool = False,
skip_tokenizer_init: bool = False,
served_model_name: Optional[Union[str, list[str]]] = None,
limit_mm_per_prompt: Optional[Mapping[str, int]] = None,
use_async_output_proc: bool = True,
config_format: ConfigFormat = ConfigFormat.AUTO,
hf_token: Optional[Union[bool, str]] = None,
hf_overrides: Optional[HfOverrides] = None,
mm_processor_kwargs: Optional[dict[str, Any]] = None,
disable_mm_preprocessor_cache: bool = False,
override_neuron_config: Optional[dict[str, Any]] = None,
override_pooler_config: Optional["PoolerConfig"] = None,
logits_processor_pattern: Optional[str] = None,
generation_config: str = "auto",
enable_sleep_mode: bool = False,
override_generation_config: Optional[dict[str, Any]] = None,
model_impl: Union[str, ModelImpl] = ModelImpl.AUTO,
) -> None:
self.model = maybe_model_redirect(model)
self.tokenizer = maybe_model_redirect(tokenizer)
self.hf_config_path = hf_config_path
if isinstance(hf_config_path, str):
self.hf_config_path = maybe_model_redirect(hf_config_path)
self.tokenizer_mode = tokenizer_mode
self.trust_remote_code = trust_remote_code
self.allowed_local_media_path = allowed_local_media_path
self.seed = seed
self.revision = revision
self.code_revision = code_revision
self.rope_scaling = rope_scaling
self.rope_theta = rope_theta
self.model_impl = model_impl
if hf_overrides is None:
hf_overrides = {}
if callable(hf_overrides):
hf_overrides_kw: dict[str, Any] = {}
hf_overrides_fn = hf_overrides
else:
hf_overrides_kw = hf_overrides
hf_overrides_fn = None
if rope_scaling is not None:
hf_override: dict[str, Any] = {"rope_scaling": rope_scaling}
hf_overrides_kw.update(hf_override)
hf_overrides_str = json.dumps(hf_overrides)
msg = ("`--rope-scaling` will be removed in a future release. "
f"'Please instead use `--hf-overrides '{hf_overrides_str}'`")
warnings.warn(DeprecationWarning(msg), stacklevel=2)
if rope_theta is not None:
hf_override = {"rope_theta": rope_theta}
hf_overrides_kw.update(hf_override)
hf_overrides_str = json.dumps(hf_overrides)
msg = ("`--rope-theta` will be removed in a future release. "
f"'Please instead use `--hf-overrides '{hf_overrides_str}'`")
warnings.warn(DeprecationWarning(msg), stacklevel=2)
self.maybe_pull_model_tokenizer_for_s3(model, tokenizer)
if (backend := envs.VLLM_ATTENTION_BACKEND
) and backend == "FLASHINFER" and find_spec("flashinfer") is None:
raise ValueError(
"VLLM_ATTENTION_BACKEND is set to FLASHINFER, but flashinfer "
"module was not found. See "
"https://github.com/vllm-project/vllm/blob/main/docker/Dockerfile " # noqa: E501
"for instructions on how to install it.")
# The tokenizer version is consistent with the model version by default.
if tokenizer_revision is None:
self.tokenizer_revision = revision
else:
self.tokenizer_revision = tokenizer_revision
self.quantization = quantization
self.enforce_eager = enforce_eager
self.max_seq_len_to_capture = max_seq_len_to_capture
self.max_logprobs = max_logprobs
self.disable_sliding_window = disable_sliding_window
self.disable_cascade_attn = disable_cascade_attn
self.skip_tokenizer_init = skip_tokenizer_init
self.enable_sleep_mode = enable_sleep_mode
from vllm.platforms import current_platform
hf_config = get_config(self.hf_config_path or self.model,
trust_remote_code, revision, code_revision,
config_format)
if hf_overrides_kw:
logger.info("Overriding HF config with %s", hf_overrides_kw)
hf_config.update(hf_overrides_kw)
if hf_overrides_fn:
logger.info("Overriding HF config with %s", hf_overrides_fn)
hf_config = hf_overrides_fn(hf_config)
self.hf_config = hf_config
self.hf_text_config = get_hf_text_config(self.hf_config)
self.attention_chunk_size = getattr(self.hf_text_config,
"attention_chunk_size", None)
self.encoder_config = self._get_encoder_config()
self.hf_image_processor_config = get_hf_image_processor_config(
self.model, hf_token=hf_token, revision=revision)
self.dtype = _get_and_verify_dtype(self.hf_config, dtype)
self.use_async_output_proc = use_async_output_proc
self.mm_processor_kwargs = mm_processor_kwargs
self.disable_mm_preprocessor_cache = disable_mm_preprocessor_cache
# Set enforce_eager to False if the value is unset.
if self.enforce_eager is None:
self.enforce_eager = False
interleaved_attn_models = ["gemma2", "gemma3_text", "cohere2"]
sliding_window = getattr(self.hf_text_config, "sliding_window", None)
has_interleaved_attention = (sliding_window is not None) and (
isinstance(sliding_window, list) or
(self.hf_text_config.model_type in interleaved_attn_models))
if (not self.disable_sliding_window and has_interleaved_attention):
if (backend :=
envs.VLLM_ATTENTION_BACKEND) in ("XFORMERS", "FLASHINFER"):
sliding_window_len_min = get_min_sliding_window(
self.hf_text_config.sliding_window)
logger.warning_once(
f"{self.hf_text_config.model_type} has interleaved "
"attention, which is currently not supported by the "
f"{backend} backend. Disabling sliding window and capping "
"the max length to the sliding window size "
f"({sliding_window_len_min}).")
self.disable_sliding_window = True
else:
# for a model with interleaved attention,
# the scheduler and the model treat it as full attention
# (i.e., not dropping any tokens outside the window).
# only the attention layer itself is aware of the sliding
# window, and use the window size to compute the attention.
self.hf_text_config.interleaved_sliding_window = sliding_window
delattr(self.hf_text_config, "sliding_window")
sliding_window = None
self.max_model_len = _get_and_verify_max_len(
hf_config=self.hf_text_config,
max_model_len=max_model_len,
disable_sliding_window=self.disable_sliding_window,
sliding_window_len=self.get_hf_config_sliding_window(),
spec_target_max_model_len=spec_target_max_model_len,
encoder_config=self.encoder_config)
self.served_model_name = get_served_model_name(model, served_model_name)
self.multimodal_config = self._init_multimodal_config(limit_mm_per_prompt)
if not self.skip_tokenizer_init:
self._verify_tokenizer_mode()
self.is_attention_free = self._init_attention_free()
self.is_hybrid = self._init_is_hybrid()
self.has_noops = self._init_has_noops()
self.has_inner_state = self._init_has_inner_state()
if current_platform.is_neuron():
self.override_neuron_config = override_neuron_config
else:
self.override_neuron_config = None
supported_tasks, task = self._resolve_task(task)
self.supported_tasks = supported_tasks
self.task: Final = task # type: ignore
if self.task in ("draft", "generate"):
self.truncation_side = "left"
else:
self.truncation_side = "right"
self.pooler_config = self._init_pooler_config(override_pooler_config)
self.logits_processor_pattern = logits_processor_pattern
self.generation_config = generation_config
self.override_generation_config = override_generation_config or {}
self._verify_quantization()
self._verify_cuda_graph()
self._verify_bnb_config()
# The platform assertion is deleted to support the npu platform.
ModelConfig.__init__ = new_init

View File

@ -63,10 +63,17 @@ class NPUPlatform(Platform):
supported_quantization: list[str] = ["ascend"]
def is_sleep_mode_available(self) -> bool:
return True
@classmethod
def pre_register_and_update(cls,
parser: Optional[FlexibleArgumentParser] = None
) -> None:
# Adapt the global patch here.
from vllm_ascend.utils import adapt_patch
adapt_patch(is_global_patch=True)
from vllm_ascend.quantization.quant_config import \
AscendQuantConfig # noqa: F401

View File

@ -36,13 +36,14 @@ from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sequence import (ExecuteModelRequest, IntermediateTensors,
SequenceGroupMetadata, SequenceGroupMetadataDelta)
from vllm.utils import bind_kv_cache
from vllm.utils import GiB_bytes, bind_kv_cache
from vllm.worker.cache_engine import CacheEngine
from vllm.worker.enc_dec_model_runner import EncoderDecoderModelRunner
from vllm.worker.model_runner_base import ModelRunnerBase
from vllm.worker.worker_base import (LocalOrDistributedWorkerBase, WorkerBase,
WorkerInput)
from vllm_ascend.device_allocator.camem import CaMemAllocator
from vllm_ascend.platform import NPUPlatform
from vllm_ascend.utils import try_register_lib, vllm_version_is
from vllm_ascend.worker.model_runner import NPUModelRunner
@ -159,6 +160,24 @@ class NPUWorker(LocalOrDistributedWorkerBase):
else:
self.profiler = None
def sleep(self, level: int = 1) -> None:
NPUPlatform.set_device(self.device)
free_bytes_before_sleep = NPUPlatform.mem_get_info()[0]
allocator = CaMemAllocator.get_instance()
allocator.sleep(offload_tags=("weights", ) if level == 1 else tuple())
free_bytes_after_sleep, total = NPUPlatform.mem_get_info()
freed_bytes = free_bytes_after_sleep - free_bytes_before_sleep
used_bytes = total - free_bytes_after_sleep
assert freed_bytes >= 0, "Memory usage increased after sleeping."
logger.info(
"Sleep mode freed %.2f GiB memory, "
"%.2f GiB memory is still in use.", freed_bytes / GiB_bytes,
used_bytes / GiB_bytes)
def wake_up(self, tags: Optional[list[str]] = None) -> None:
allocator = CaMemAllocator.get_instance()
allocator.wake_up(tags=tags)
def init_device(self) -> None:
if self.device_config.device.type == "npu":
self.device = torch.device(f"npu:{self.local_rank}")
@ -176,7 +195,17 @@ class NPUWorker(LocalOrDistributedWorkerBase):
set_random_seed(self.model_config.seed)
def load_model(self):
self.model_runner.load_model()
if self.vllm_config.model_config.enable_sleep_mode:
allocator = CaMemAllocator.get_instance()
assert allocator.get_current_usage() == 0, (
"Sleep mode can only be "
"used for one instance per process.")
context = allocator.use_memory_pool(tag="weights")
else:
from contextlib import nullcontext
context = nullcontext() # type: ignore
with context:
self.model_runner.load_model()
def start_profile(self):
if self.profiler is None:
@ -263,8 +292,14 @@ class NPUWorker(LocalOrDistributedWorkerBase):
self.cache_config.num_gpu_blocks = num_gpu_blocks
self.cache_config.num_cpu_blocks = num_cpu_blocks
self._init_cache_engine()
if self.vllm_config.model_config.enable_sleep_mode:
allocator = CaMemAllocator.get_instance()
context = allocator.use_memory_pool(tag="kv_cache")
else:
from contextlib import nullcontext
context = nullcontext() # type: ignore
with context:
self._init_cache_engine()
self._warm_up_model()
def _init_cache_engine(self):

View File

@ -89,6 +89,12 @@ class NPUWorker(WorkerBase):
self.profiler = self._init_profiler()
def sleep(self, level: int = 1) -> None:
logger.error("Sleep mode is only supported on v0")
def wake_up(self, tags: Optional[list[str]] = None) -> None:
logger.error("Sleep mode is only supported on v0")
def init_device(self):
if self.device_config.device.type == "npu":
self.device = torch.device(f"npu:{self.local_rank}")