From 84563fc65d938f1b2655eac18164fca5666d3c67 Mon Sep 17 00:00:00 2001 From: Shuqiao Li Date: Fri, 18 Apr 2025 13:11:39 +0800 Subject: [PATCH] Add sleep mode feature for Ascend NPU (#513) ### What this PR does / why we need it? This PR adds sleep mode feature for vllm-ascend, when sleeps, we do mainly two things: - offload model weights - discard kv cache RLHF tools(such as https://github.com/volcengine/verl and https://github.com/OpenRLHF/OpenRLHF) have a strong need of sleep mode to accelerate the training process. This PR may solve #375 and #320 . ### Does this PR introduce _any_ user-facing change? No existing user interfaces changed. Users will have two new methods(`sleep()` and `wake_up()`) to use. ### How was this patch tested? This PR is tested with Qwen/Qwen2.5-0.5B-Instruct. At first, we have free NPU memory M1. After `llm = LLM("Qwen/Qwen2.5-0.5B-Instruct", enable_sleep_mode=True)` executed, we have free NPU memory M2. M2 < M1. Then we call `llm.sleep(level=1)`, we have free NPU memory M3. We have M3 > M2, M3 is very close to M1. Plus, we have the same output tokens before sleep and after wake up, with the config of `SamplingParams(temperature=0, max_tokens=10)` and with the same input tokens of course. This PR is utilizing the CMake procedure of #371 , thanks a lot. Signed-off-by: Shuqiao Li --- CMakeLists.txt | 2 +- csrc/camem_allocator.cpp | 338 ++++++++++++++++++ setup.py | 4 + tests/singlecard/test_camem.py | 92 +++++ vllm_ascend/__init__.py | 3 - vllm_ascend/device_allocator/__init__.py | 0 vllm_ascend/device_allocator/camem.py | 283 +++++++++++++++ vllm_ascend/envs.py | 4 + .../patch/platform/patch_0_8_4/__init__.py | 4 +- .../platform/patch_0_8_4/patch_config.py | 243 +++++++++++++ vllm_ascend/platform.py | 7 + vllm_ascend/worker/worker.py | 43 ++- vllm_ascend/worker/worker_v1.py | 6 + 13 files changed, 1020 insertions(+), 9 deletions(-) create mode 100644 csrc/camem_allocator.cpp create mode 100644 tests/singlecard/test_camem.py create mode 100644 vllm_ascend/device_allocator/__init__.py create mode 100644 vllm_ascend/device_allocator/camem.py create mode 100644 vllm_ascend/patch/platform/patch_0_8_4/patch_config.py diff --git a/CMakeLists.txt b/CMakeLists.txt index 682b93432..e3bbc10b5 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -3,7 +3,7 @@ project(vllm_ascend_C) # include(CheckCXXcompilerFlag) # check_cxx_compiler_flag("-std=c++17", COMPILER_SUPPORTS_CXX17) - +set(CMAKE_CXX_STANDARD 17) include(${CMAKE_CURRENT_LIST_DIR}/cmake/utils.cmake) diff --git a/csrc/camem_allocator.cpp b/csrc/camem_allocator.cpp new file mode 100644 index 000000000..8cba79dc5 --- /dev/null +++ b/csrc/camem_allocator.cpp @@ -0,0 +1,338 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +extern "C" { + +#define PY_SSIZE_T_CLEAN +#include + +#include +#include "acl/acl.h" + +// Global references to Python callables +// NOTE: this is borrowed reference, so we don't need to DECREF them. 
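+// my_malloc()/my_free() below invoke these callbacks so that the Python-side
+// CaMemAllocator can record and look up each allocation's
+// (device, aligned_size, d_mem, p_memHandle) handle. Both callables live in
+// process-wide globals that are shared by every allocation.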
+// This brings the limitation that the allocator needs to be singleton. +static PyObject* g_python_malloc_callback = nullptr; +static PyObject* g_python_free_callback = nullptr; + + +// --------------------------------------------------------------------------- +// Helper functions: + +void ensure_context(unsigned long long device) { + aclrtContext pctx; + aclrtGetCurrentContext(&pctx); + if (!pctx) { + // Ensure device context. + aclrtCreateContext(&pctx, device); + aclrtSetCurrentContext(pctx); + } +} + +void create_and_map(unsigned long long device, ssize_t size, void* d_mem, + aclrtDrvMemHandle* p_memHandle) { + ensure_context(device); + // Define memory allocation properties + aclrtPhysicalMemProp prop = {}; + prop.handleType = ACL_MEM_HANDLE_TYPE_NONE ; + prop.allocationType = ACL_MEM_ALLOCATION_TYPE_PINNED; + prop.memAttr = ACL_HBM_MEM_HUGE; + prop.location.id = device; + prop.location.type = ACL_MEM_LOCATION_TYPE_DEVICE; + prop.reserve = 0; + + // Allocate memory using aclrtMallocPhysical + aclError error_code = aclrtMallocPhysical(p_memHandle, size, &prop, 0); + if (error_code != 0) { + std::cerr << "acl Error, code: " << error_code << " at " << __FILE__ << ":" \ + << __LINE__ << std::endl; + return; + } + error_code = aclrtMapMem(d_mem, size, 0, *p_memHandle, 0); + if (error_code != 0) { + std::cerr << "acl Error, code: " << error_code << " at " << __FILE__ << ":" \ + << __LINE__ << std::endl; + return; + } +} + +void unmap_and_release(unsigned long long device, ssize_t size, + void* d_mem, + aclrtDrvMemHandle* p_memHandle) { + // std::cout << "unmap_and_release: device=" << device << ", size=" << size << + // ", d_mem=" << d_mem << ", p_memHandle=" << p_memHandle << std::endl; + ensure_context(device); + aclError error_code = aclrtUnmapMem(d_mem); + if (error_code != 0) { + std::cerr << "acl Error, code: " << error_code << " at " << __FILE__ << ":" \ + << __LINE__ << std::endl; + return; + } + error_code = aclrtFreePhysical(*p_memHandle); + if (error_code != 0) { + std::cerr << "acl Error, code: " << error_code << " at " << __FILE__ << ":" \ + << __LINE__ << std::endl; + return; + } +} + +PyObject* create_tuple_from_c_integers(unsigned long long a, + unsigned long long b, + unsigned long long c, + unsigned long long d) { + // Create a new tuple of size 4 + PyObject* tuple = PyTuple_New(4); + if (!tuple) { + return NULL; // Return NULL on failure + } + + // Convert integers to Python objects and set them in the tuple + PyTuple_SetItem( + tuple, 0, + PyLong_FromUnsignedLongLong(a)); // Steals reference to the PyLong + PyTuple_SetItem(tuple, 1, PyLong_FromUnsignedLongLong(b)); + PyTuple_SetItem(tuple, 2, PyLong_FromUnsignedLongLong(c)); + PyTuple_SetItem(tuple, 3, PyLong_FromUnsignedLongLong(d)); + + // Note: PyTuple_SetItem "steals" a reference to each object, + // so we do not need to Py_DECREF the PyLong objects explicitly. 
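+  // The tuple layout matches HandleType on the Python side:
+  // (py_device, py_alignedSize, py_d_mem, py_p_memHandle), all unsigned 64-bit.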
+ + return tuple; // Return the created tuple +} + +// --------------------------------------------------------------------------- +// Our exported C functions that call Python: + +__attribute__ ((visibility("default"))) void* my_malloc(ssize_t size, int device, aclrtStream stream) { + ensure_context(device); + + // first allocation, align the size, and reserve an address, and also allocate + // a aclrtDrvMemHandle + + // Define memory allocation properties + aclrtPhysicalMemProp prop = {}; + prop.handleType = ACL_MEM_HANDLE_TYPE_NONE ; + prop.allocationType = ACL_MEM_ALLOCATION_TYPE_PINNED; + prop.memAttr = ACL_HBM_MEM_HUGE; + prop.location.id = device; + prop.location.type = ACL_MEM_LOCATION_TYPE_DEVICE; + prop.reserve = 0; + + // Check if the allocation is supported + size_t granularity; + aclError error_code = aclrtMemGetAllocationGranularity(&prop, + ACL_RT_MEM_ALLOC_GRANULARITY_MINIMUM, + &granularity); + if (error_code != 0) { + std::cerr << "acl Error, code: " << error_code << " at " << __FILE__ << ":" \ + << __LINE__ << std::endl; + return nullptr; + } + size_t alignedSize = ((size + granularity - 1) / granularity) * granularity; + void *d_mem; + error_code = aclrtReserveMemAddress(&d_mem, alignedSize, 0, nullptr, 0); + if (error_code != 0) { + std::cerr << "acl Error, code: " << error_code << " at " << __FILE__ << ":" \ + << __LINE__ << std::endl; + return nullptr; + } + // allocate the aclrtDrvMemHandle + aclrtDrvMemHandle* p_memHandle = + (aclrtDrvMemHandle*)malloc(sizeof(aclrtDrvMemHandle)); + + if (!g_python_malloc_callback) { + std::cerr << "ERROR: g_python_malloc_callback not set.\n"; + return nullptr; + } + + // Acquire GIL (not in stable ABI officially, but often works) + PyGILState_STATE gstate = PyGILState_Ensure(); + + PyObject* arg_tuple = create_tuple_from_c_integers( + (unsigned long long)device, (unsigned long long)alignedSize, + (unsigned long long)d_mem, (unsigned long long)p_memHandle); + + // Call g_python_malloc_callback + PyObject* py_result = + PyObject_CallFunctionObjArgs(g_python_malloc_callback, arg_tuple, NULL); + Py_DECREF(arg_tuple); + + if (!py_result) { + PyErr_Print(); + PyGILState_Release(gstate); + return nullptr; + } + + PyGILState_Release(gstate); + + // do the final mapping + create_and_map(device, alignedSize, d_mem, p_memHandle); + + return (void*)d_mem; +} + +__attribute__ ((visibility("default"))) void my_free(void* ptr, ssize_t size, int device, aclrtStream stream) { + // get memory handle from the pointer + if (!g_python_free_callback) { + std::cerr << "ERROR: g_python_free_callback not set.\n"; + return; + } + + // Acquire GIL (not in stable ABI officially, but often works) + PyGILState_STATE gstate = PyGILState_Ensure(); + + PyObject* py_ptr = + PyLong_FromUnsignedLongLong(reinterpret_cast(ptr)); + + PyObject* py_result = + PyObject_CallFunctionObjArgs(g_python_free_callback, py_ptr, NULL); + + if (!py_result || !PyTuple_Check(py_result) || PyTuple_Size(py_result) != 4) { + PyErr_SetString(PyExc_TypeError, "Expected a tuple of size 4"); + return; + } + + unsigned long long recv_device, recv_size; + unsigned long long recv_d_mem, recv_p_memHandle; + // Unpack the tuple into four C integers + if (!PyArg_ParseTuple(py_result, "KKKK", &recv_device, &recv_size, + &recv_d_mem, &recv_p_memHandle)) { + // PyArg_ParseTuple sets an error if it fails + return; + } + + PyGILState_Release(gstate); + + // recv_size == size + // recv_device == device + + // Free memory + + void *d_mem = (void*)recv_d_mem; + // allocate the aclrtDrvMemHandle + 
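+  // (a cast, not an allocation: this recovers the handle pointer that
+  // my_malloc() allocated; it is released and free()d below)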
aclrtDrvMemHandle* p_memHandle = + (aclrtDrvMemHandle*)recv_p_memHandle; + unmap_and_release(device, size, d_mem, p_memHandle); + + // free address and the handle + aclError error_code = aclrtReleaseMemAddress(d_mem); + if (error_code != 0) { + std::cerr << "acl Error, code: " << error_code << " at " << __FILE__ << ":" \ + << __LINE__ << std::endl; + return; + } + free(p_memHandle); +} + +// --------------------------------------------------------------------------- +// Python extension boilerplate: + +// Python-exposed function: init_module(python_malloc, python_free) +static PyObject* py_init_module(PyObject* self, PyObject* args) { + PyObject* malloc_callback = nullptr; + PyObject* free_callback = nullptr; + + if (!PyArg_ParseTuple(args, "OO", &malloc_callback, &free_callback)) { + return nullptr; + } + + if (!PyCallable_Check(malloc_callback) || !PyCallable_Check(free_callback)) { + PyErr_SetString(PyExc_TypeError, "Both arguments must be callables"); + return nullptr; + } + + // Save the Python callables + // This module does not handle GC of these objects, so they must be kept alive + // outside of this module. + g_python_malloc_callback = malloc_callback; + g_python_free_callback = free_callback; + + Py_RETURN_NONE; +} + +static PyObject* python_unmap_and_release(PyObject* self, PyObject* args) { + if (!args || !PyTuple_Check(args) || PyTuple_Size(args) != 4) { + PyErr_SetString(PyExc_TypeError, "Expected a tuple of size 4"); + return nullptr; + } + + unsigned long long recv_device, recv_size; + unsigned long long recv_d_mem, recv_p_memHandle; + // Unpack the tuple into four C integers + if (!PyArg_ParseTuple(args, "KKKK", &recv_device, &recv_size, &recv_d_mem, + &recv_p_memHandle)) { + // PyArg_ParseTuple sets an error if it fails + return nullptr; + } + + void *d_mem_ptr = (void*)recv_d_mem; + aclrtDrvMemHandle* p_memHandle = + (aclrtDrvMemHandle*)recv_p_memHandle; + + unmap_and_release(recv_device, recv_size, d_mem_ptr, p_memHandle); + + Py_RETURN_NONE; +} + +static PyObject* python_create_and_map(PyObject* self, PyObject* args) { + if (!args || !PyTuple_Check(args) || PyTuple_Size(args) != 4) { + PyErr_SetString(PyExc_TypeError, "Expected a tuple of size 4"); + return nullptr; + } + + unsigned long long recv_device, recv_size; + unsigned long long recv_d_mem, recv_p_memHandle; + // Unpack the tuple into four C integers + if (!PyArg_ParseTuple(args, "KKKK", &recv_device, &recv_size, &recv_d_mem, + &recv_p_memHandle)) { + // PyArg_ParseTuple sets an error if it fails + return nullptr; + } + + void *d_mem_ptr = (void*)recv_d_mem; + aclrtDrvMemHandle* p_memHandle = + (aclrtDrvMemHandle*)recv_p_memHandle; + + create_and_map(recv_device, recv_size, d_mem_ptr, p_memHandle); + + Py_RETURN_NONE; +} + +static PyMethodDef module_methods[] = { + {"init_module", (PyCFunction)py_init_module, METH_VARARGS, + "Initialize module with python_malloc and python_free callables."}, + {"python_create_and_map", (PyCFunction)python_create_and_map, METH_VARARGS, + "Create and map memory on the device."}, + {"python_unmap_and_release", (PyCFunction)python_unmap_and_release, + METH_VARARGS, "Unmap and release memory on the device."}, + {NULL, NULL, 0, NULL} // sentinel +}; + +static struct PyModuleDef camem_allocator_module = { + PyModuleDef_HEAD_INIT, "camem_allocator", + "CANN-mem-based allocator for NPUPluggableAllocator", -1, module_methods}; + +PyMODINIT_FUNC PyInit_vllm_ascend_C(void) { + // Initialize the module + PyObject* module = PyModule_Create(&camem_allocator_module); + if (!module) { + return 
NULL; + } + return module; +} +} // extern "C" diff --git a/setup.py b/setup.py index 2b4d03dad..35ebba14e 100644 --- a/setup.py +++ b/setup.py @@ -123,6 +123,10 @@ class cmake_build_ext(build_ext): cmake_args += [f"-DCMAKE_BUILD_TYPE={envs.CMAKE_BUILD_TYPE}"] # Default dump the compile commands for lsp cmake_args += ["-DCMAKE_EXPORT_COMPILE_COMMANDS=1"] + if envs.CXX_COMPILER is not None: + cmake_args += [f"-DCMAKE_CXX_COMPILER={envs.CXX_COMPILER}"] + if envs.C_COMPILER is not None: + cmake_args += [f"-DCMAKE_C_COMPILER={envs.C_COMPILER}"] if envs.VERBOSE: cmake_args += ["-DCMAKE_VERBOSE_MAKEFILE=ON"] diff --git a/tests/singlecard/test_camem.py b/tests/singlecard/test_camem.py new file mode 100644 index 000000000..7ebb70cf2 --- /dev/null +++ b/tests/singlecard/test_camem.py @@ -0,0 +1,92 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# This file is a part of the vllm-ascend project. +# Adapted from vllm/tests/basic_correctness/test_basic_correctness.py +# Copyright 2023 The vLLM team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import os + +import torch +from vllm import LLM, SamplingParams +from vllm.utils import GiB_bytes + +from tests.utils import fork_new_process_for_each_test +from vllm_ascend.device_allocator.camem import CaMemAllocator + +try: + import torch_npu # noqa: F401 +except ImportError: + print("Failed to import torch_npu.") + + +@fork_new_process_for_each_test +def test_basic_camem(): + # some tensors from default memory pool + shape = (1024, 1024) + x = torch.empty(shape, device='npu:0') + x.zero_() + + # some tensors from custom memory pool + allocator = CaMemAllocator.get_instance() + with allocator.use_memory_pool(): + # custom memory pool + y = torch.empty(shape, device='npu:0') + y.zero_() + y += 1 + z = torch.empty(shape, device='npu:0') + z.zero_() + z += 2 + + # they can be used together + output = x + y + z + assert torch.allclose(output, torch.ones_like(output) * 3) + + free_bytes = torch_npu.npu.mem_get_info()[0] + allocator.sleep() + free_bytes_after_sleep = torch_npu.npu.mem_get_info()[0] + assert free_bytes_after_sleep > free_bytes + allocator.wake_up() + + # they can be used together + output = x + y + z + assert torch.allclose(output, torch.ones_like(output) * 3) + + +@fork_new_process_for_each_test +def test_end_to_end(): + os.environ["VLLM_USE_V1"] = "0" + free, total = torch_npu.npu.mem_get_info() + used_bytes_baseline = total - free # in case other process is running + llm = LLM("Qwen/Qwen2.5-0.5B-Instruct", enable_sleep_mode=True) + prompt = "How are you?" + sampling_params = SamplingParams(temperature=0, max_tokens=10) + output = llm.generate(prompt, sampling_params) + + # the benefit of `llm.sleep(level=2)` is mainly CPU memory usage, + # which is difficult to measure in the test. therefore, we only + # test sleep level 1 here. 
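+    # level 1 offloads the model weights to host memory and discards the
+    # KV cache; wake_up() maps device memory back and restores the weights.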
+ llm.sleep(level=1) + + free_gpu_bytes_after_sleep, total = torch_npu.npu.mem_get_info() + used_bytes = total - free_gpu_bytes_after_sleep - used_bytes_baseline + # now the memory usage should be less than the model weights + # (0.5B model, 1GiB weights) + assert used_bytes < 1 * GiB_bytes + + llm.wake_up() + output2 = llm.generate(prompt, sampling_params) + + # cmp output + assert output[0].outputs[0].text == output2[0].outputs[0].text diff --git a/vllm_ascend/__init__.py b/vllm_ascend/__init__.py index 926c77ee7..7588e70ed 100644 --- a/vllm_ascend/__init__.py +++ b/vllm_ascend/__init__.py @@ -18,9 +18,6 @@ def register(): """Register the NPU platform.""" - # Adapt the global patch here. - from vllm_ascend.utils import adapt_patch - adapt_patch(is_global_patch=True) return "vllm_ascend.platform.NPUPlatform" diff --git a/vllm_ascend/device_allocator/__init__.py b/vllm_ascend/device_allocator/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/vllm_ascend/device_allocator/camem.py b/vllm_ascend/device_allocator/camem.py new file mode 100644 index 000000000..f65c37001 --- /dev/null +++ b/vllm_ascend/device_allocator/camem.py @@ -0,0 +1,283 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# This file is a part of the vllm-ascend project. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# CANN-mem-based pytorch pluggable allocator to implement sleep mode. +# +import dataclasses +import os +from contextlib import contextmanager +from typing import Any, Callable, Dict, Optional, Tuple, Union + +import torch +from acl.rt import memcpy # type: ignore # noqa: F401 +from vllm.logger import logger + +try: + import torch_npu # noqa: F401 +except ImportError: + print("Failed to import torch_npu.") + +from vllm.utils import is_pin_memory_available + + +def find_loaded_library(lib_name) -> Optional[str]: + """ + According to according to https://man7.org/linux/man-pages/man5/proc_pid_maps.5.html, + the file `/proc/self/maps` contains the memory maps of the process, which includes the + shared libraries loaded by the process. We can use this file to find the path of the + a loaded library. 
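+    Each line of that file has the form
+    `address perms offset dev inode pathname`; the pathname column holds the
+    absolute path of the mapped shared object, which is what this helper returns.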
+ """ # noqa + found_line = None + with open("/proc/self/maps") as f: + for line in f: + if lib_name in line: + found_line = line + break + if found_line is None: + # the library is not loaded in the current process + return None + # if lib_name is libcudart, we need to match a line with: + # address /path/to/libcudart-hash.so.11.0 + start = found_line.index("/") + path = found_line[start:].strip() + filename = path.split("/")[-1] + assert filename.rpartition(".so")[0].startswith(lib_name), \ + f"Unexpected filename: {filename} for library {lib_name}" + return path + + +camem_available = False +try: + from vllm_ascend.vllm_ascend_C import ( # type: ignore # noqa: F401 + init_module, python_create_and_map, python_unmap_and_release) + lib_name = find_loaded_library("vllm_ascend_C") + camem_available = True +except ModuleNotFoundError as e: + logger.error("Failed to import vllm_ascend_C:%s", e) + init_module = None + python_create_and_map = None + python_unmap_and_release = None + lib_name = None + libcudart = None + +# py_device, py_alignedSize, py_d_mem, py_p_memHandle +HandleType = Tuple[int, int, int, int] + + +@dataclasses.dataclass +class AllocationData: + handle: HandleType + tag: str + cpu_backup_tensor: Optional[torch.Tensor] = None + + +def create_and_map(allocation_handle: HandleType) -> None: + python_create_and_map(*allocation_handle) + + +def unmap_and_release(allocation_handle: HandleType) -> None: + python_unmap_and_release(*allocation_handle) + + +def get_pluggable_allocator( + python_malloc_fn: Callable[[tuple[int, int, int, int]], None], + python_free_func: Callable[[int], tuple[int, int, int, int]] +) -> torch_npu.npu.memory.NPUPluggableAllocator: + init_module(python_malloc_fn, python_free_func) + new_alloc = torch_npu.npu.memory.NPUPluggableAllocator( + lib_name, 'my_malloc', 'my_free') + return new_alloc + + +@contextmanager +def use_memory_pool_with_allocator( + python_malloc_fn: Callable[[tuple[int, int, int, int]], None], + python_free_func: Callable[[int], tuple[int, int, int, int]]): + new_alloc = get_pluggable_allocator(python_malloc_fn, python_free_func) + mem_pool = torch_npu.npu.memory.MemPool(new_alloc._allocator) + with torch_npu.npu.memory.use_mem_pool(mem_pool): + yield mem_pool, new_alloc + + +class CaMemAllocator: + """ + A singleton class that manages a memory pool for CANN tensors. + The memory in this pool can be offloaded or discarded when the + allocator sleeps. + Inside the `use_memory_pool(tag)` context, all tensors created will + be allocated in the memory pool, and has the same tag as the + tag passed to the context. + When we call `sleep`, all tensors with the specified tag will be + offloaded to CPU memory, and the rest of the tensors will be discarded. + When we call `wake_up`, all tensors that are previously offloaded + will be loaded back to GPU memory, and the rest of the tensors will + have empty memory. + Why it needs to be a singleton? + When allocated tensors are garbage collected, PyTorch will call + the free callback, which will call the `python_free_callback` method. + The C-extension uses a global variable to store the function of an + instance of this class. If we create multiple instances of this class, + the global variable will be overwritten and the free callback will + not work as expected. + """ + instance = None + default_tag: str = "default" + + @staticmethod + def get_instance() -> "CaMemAllocator": + """ + CaMemAllocator is a singleton class. + We cannot call the constructor directly. 
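+        Creating a second instance would overwrite the callbacks that the
+        C extension keeps in its module-level globals (see the class docstring).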
+ Call this method to get the instance. + """ + assert camem_available, "camem allocator is not available" + if CaMemAllocator.instance is None: + CaMemAllocator.instance = CaMemAllocator() + return CaMemAllocator.instance + + def __init__(self): + conf = os.environ.get("PYTORCH_CUDA_ALLOC_CONF", "") + assert "expandable_segments:True" not in conf, \ + ("Expandable segments are not compatible with memory pool. " + "Please track https://github.com/pytorch/pytorch/issues/147851 " + "for the latest updates.") + + self.pointer_to_data: Dict[int, AllocationData] = {} + self.current_tag: str = CaMemAllocator.default_tag + self.allocator_and_pools: Dict[str, Any] = {} + + def python_malloc_callback(self, allocation_handle: HandleType) -> None: + """ + Internal method to store the allocation data + when memory is allocated in the memory pool.""" + py_d_mem = allocation_handle[2] + self.pointer_to_data[py_d_mem] = AllocationData( + allocation_handle, self.current_tag) + return + + def python_free_callback(self, ptr: int) -> HandleType: + """ + Internal method to look up the allocation data + when memory is freed in the memory pool.""" + data = self.pointer_to_data.pop(ptr) + if data.cpu_backup_tensor is not None: + data.cpu_backup_tensor = None + return data.handle + + def sleep( + self, + offload_tags: Optional[Union[Tuple[str, ...], + str]] = None) -> None: + """ + Put the allocator in sleep mode. + All data in the memory allocation with the specified tag will be + offloaded to CPU memory, and others will be discarded. + :param offload_tags: The tags of the memory allocation that will be + offloaded. The rest of the memory allocation will be discarded. + """ + if offload_tags is None: + # by default, allocated tensors are offloaded + # when the allocator sleeps + offload_tags = (CaMemAllocator.default_tag, ) + elif isinstance(offload_tags, str): + offload_tags = (offload_tags, ) + + assert isinstance(offload_tags, tuple) + + for ptr, data in self.pointer_to_data.items(): + handle = data.handle + if data.tag in offload_tags: + size_in_bytes = handle[1] + cpu_backup_tensor = torch.empty( + size_in_bytes, + dtype=torch.uint8, + device='cpu', + pin_memory=is_pin_memory_available()) + cpu_ptr = cpu_backup_tensor.data_ptr() + ACL_MEMCPY_DEVICE_TO_HOST = 2 + dest_max = cpu_ptr + size_in_bytes * 2 + memcpy(cpu_ptr, dest_max, ptr, size_in_bytes, + ACL_MEMCPY_DEVICE_TO_HOST) + data.cpu_backup_tensor = cpu_backup_tensor + unmap_and_release(handle) + + def wake_up(self, tags: Optional[list[str]] = None) -> None: + """ + Wake up the allocator from sleep mode. + All data that is previously offloaded will be loaded back to GPU + memory, and the rest of the data will have empty memory.""" + for ptr, data in self.pointer_to_data.items(): + if tags is None or data.tag in tags: + handle = data.handle + create_and_map(handle) + if data.cpu_backup_tensor is not None: + cpu_backup_tensor = data.cpu_backup_tensor + if cpu_backup_tensor is not None: + size_in_bytes = cpu_backup_tensor.numel( + ) * cpu_backup_tensor.element_size() + cpu_ptr = cpu_backup_tensor.data_ptr() + ACL_MEMCPY_HOST_TO_DEVICE = 1 + dest_max = ptr + size_in_bytes * 2 + memcpy(ptr, dest_max, cpu_ptr, size_in_bytes, + ACL_MEMCPY_HOST_TO_DEVICE) + data.cpu_backup_tensor = None + + @contextmanager + def use_memory_pool(self, tag: Optional[str] = None): + """ + A context manager to use the memory pool. + All memory allocation created inside the context will be allocated + in the memory pool, and has the specified tag. 
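+        A minimal usage sketch (the tag name is illustrative; the worker uses
+        "weights" and "kv_cache"):
+
+            allocator = CaMemAllocator.get_instance()
+            with allocator.use_memory_pool(tag="weights"):
+                w = torch.empty((1024, 1024), device="npu:0")  # pooled
+            allocator.sleep(offload_tags="weights")  # offload to CPU memory
+            allocator.wake_up()                      # map and copy back to NPU
+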
+ :param tag: The tag of the memory allocation. If None, the default tag + will be used. + """ + if tag is None: + tag = CaMemAllocator.default_tag + + assert isinstance(tag, str) + + old_tag = self.current_tag + self.current_tag = tag + with use_memory_pool_with_allocator(self.python_malloc_callback, + self.python_free_callback) as data: + # start to hit another PyTorch bug in PyTorch 2.6, + # possibly because of gc-related issue w.r.t. the allocator and + # the memory pool. + # to avoid the issue, we keep a reference of the data. + # see https://github.com/pytorch/pytorch/issues/146431 . + self.allocator_and_pools[tag] = data + yield + # PyTorch's bug, calling torch.cuda.empty_cache() will error + # when using pluggable allocator, see + # https://github.com/pytorch/pytorch/issues/145168 . + # if we have some memory allocated and then freed, + # the memory will not be released. + # right now it is fine, because we only use this allocator + # during weight loading and kv cache creation, where we only + # allocate memory. + # TODO: we need to find a way to release the memory, + # i.e. calling torch.cuda.empty_cache() + self.current_tag = old_tag + + def get_current_usage(self) -> int: + """ + Get the total number of bytes allocated in the memory pool. + """ + sum_bytes: int = 0 + for ptr, data in self.pointer_to_data.items(): + handle = data.handle + sum_bytes += handle[1] + return sum_bytes diff --git a/vllm_ascend/envs.py b/vllm_ascend/envs.py index 3ddb15ada..2d4cdbcbc 100644 --- a/vllm_ascend/envs.py +++ b/vllm_ascend/envs.py @@ -56,6 +56,10 @@ env_variables: Dict[str, Callable[[], Any]] = { lambda: os.getenv("LLMDATADIST_COMM_PORT", "26000"), "LLMDATADIST_SYNC_CACHE_WAIT_TIME": lambda: os.getenv("LLMDATADIST_SYNC_CACHE_WAIT_TIME", "5000"), + "CXX_COMPILER": + lambda: os.getenv("CXX_COMPILER", None), + "C_COMPILER": + lambda: os.getenv("C_COMPILER", None) } # end-env-vars-definition diff --git a/vllm_ascend/patch/platform/patch_0_8_4/__init__.py b/vllm_ascend/patch/platform/patch_0_8_4/__init__.py index 2ed088b74..d1c5ac2ce 100644 --- a/vllm_ascend/patch/platform/patch_0_8_4/__init__.py +++ b/vllm_ascend/patch/platform/patch_0_8_4/__init__.py @@ -13,4 +13,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -# \ No newline at end of file +# + +import vllm_ascend.patch.platform.patch_0_8_4.patch_config # noqa \ No newline at end of file diff --git a/vllm_ascend/patch/platform/patch_0_8_4/patch_config.py b/vllm_ascend/patch/platform/patch_0_8_4/patch_config.py new file mode 100644 index 000000000..4a30aaa85 --- /dev/null +++ b/vllm_ascend/patch/platform/patch_0_8_4/patch_config.py @@ -0,0 +1,243 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# This file is a part of the vllm-ascend project. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
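+# NOTE: `new_init` below mirrors vllm.config.ModelConfig.__init__ from
+# vLLM 0.8.4, with the platform assertion deleted so that the NPU platform is
+# accepted (see the note at the bottom of this file).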
+# + +import json +import warnings +from importlib.util import find_spec +from typing import Any, Final, Literal, Mapping, Optional, Union + +import torch +import vllm.envs as envs +from vllm.config import (HfOverrides, ModelConfig, ModelImpl, PoolerConfig, + TaskOption, _get_and_verify_dtype, + _get_and_verify_max_len, get_min_sliding_window, + get_served_model_name, logger) +from vllm.transformers_utils.config import (ConfigFormat, get_config, + get_hf_image_processor_config, + get_hf_text_config) +from vllm.transformers_utils.utils import maybe_model_redirect + + +def new_init( + self, + model: str, + task: Union[TaskOption, Literal["draft"]], + tokenizer: str, + tokenizer_mode: str, + trust_remote_code: bool, + dtype: Union[str, torch.dtype], + seed: int, + hf_config_path: Optional[str] = None, + allowed_local_media_path: str = "", + revision: Optional[str] = None, + code_revision: Optional[str] = None, + rope_scaling: Optional[dict[str, Any]] = None, + rope_theta: Optional[float] = None, + tokenizer_revision: Optional[str] = None, + max_model_len: Optional[int] = None, + spec_target_max_model_len: Optional[int] = None, + quantization: Optional[str] = None, + enforce_eager: Optional[bool] = None, + max_seq_len_to_capture: Optional[int] = None, + max_logprobs: int = 20, + disable_sliding_window: bool = False, + disable_cascade_attn: bool = False, + skip_tokenizer_init: bool = False, + served_model_name: Optional[Union[str, list[str]]] = None, + limit_mm_per_prompt: Optional[Mapping[str, int]] = None, + use_async_output_proc: bool = True, + config_format: ConfigFormat = ConfigFormat.AUTO, + hf_token: Optional[Union[bool, str]] = None, + hf_overrides: Optional[HfOverrides] = None, + mm_processor_kwargs: Optional[dict[str, Any]] = None, + disable_mm_preprocessor_cache: bool = False, + override_neuron_config: Optional[dict[str, Any]] = None, + override_pooler_config: Optional["PoolerConfig"] = None, + logits_processor_pattern: Optional[str] = None, + generation_config: str = "auto", + enable_sleep_mode: bool = False, + override_generation_config: Optional[dict[str, Any]] = None, + model_impl: Union[str, ModelImpl] = ModelImpl.AUTO, +) -> None: + self.model = maybe_model_redirect(model) + self.tokenizer = maybe_model_redirect(tokenizer) + + self.hf_config_path = hf_config_path + if isinstance(hf_config_path, str): + self.hf_config_path = maybe_model_redirect(hf_config_path) + + self.tokenizer_mode = tokenizer_mode + self.trust_remote_code = trust_remote_code + self.allowed_local_media_path = allowed_local_media_path + self.seed = seed + self.revision = revision + self.code_revision = code_revision + self.rope_scaling = rope_scaling + self.rope_theta = rope_theta + self.model_impl = model_impl + + if hf_overrides is None: + hf_overrides = {} + + if callable(hf_overrides): + hf_overrides_kw: dict[str, Any] = {} + hf_overrides_fn = hf_overrides + else: + hf_overrides_kw = hf_overrides + hf_overrides_fn = None + + if rope_scaling is not None: + hf_override: dict[str, Any] = {"rope_scaling": rope_scaling} + hf_overrides_kw.update(hf_override) + hf_overrides_str = json.dumps(hf_overrides) + msg = ("`--rope-scaling` will be removed in a future release. " + f"'Please instead use `--hf-overrides '{hf_overrides_str}'`") + warnings.warn(DeprecationWarning(msg), stacklevel=2) + if rope_theta is not None: + hf_override = {"rope_theta": rope_theta} + hf_overrides_kw.update(hf_override) + hf_overrides_str = json.dumps(hf_overrides) + msg = ("`--rope-theta` will be removed in a future release. 
" + f"'Please instead use `--hf-overrides '{hf_overrides_str}'`") + warnings.warn(DeprecationWarning(msg), stacklevel=2) + + self.maybe_pull_model_tokenizer_for_s3(model, tokenizer) + + if (backend := envs.VLLM_ATTENTION_BACKEND + ) and backend == "FLASHINFER" and find_spec("flashinfer") is None: + raise ValueError( + "VLLM_ATTENTION_BACKEND is set to FLASHINFER, but flashinfer " + "module was not found. See " + "https://github.com/vllm-project/vllm/blob/main/docker/Dockerfile " # noqa: E501 + "for instructions on how to install it.") + + # The tokenizer version is consistent with the model version by default. + if tokenizer_revision is None: + self.tokenizer_revision = revision + else: + self.tokenizer_revision = tokenizer_revision + self.quantization = quantization + self.enforce_eager = enforce_eager + self.max_seq_len_to_capture = max_seq_len_to_capture + self.max_logprobs = max_logprobs + self.disable_sliding_window = disable_sliding_window + self.disable_cascade_attn = disable_cascade_attn + self.skip_tokenizer_init = skip_tokenizer_init + self.enable_sleep_mode = enable_sleep_mode + + from vllm.platforms import current_platform + + hf_config = get_config(self.hf_config_path or self.model, + trust_remote_code, revision, code_revision, + config_format) + + if hf_overrides_kw: + logger.info("Overriding HF config with %s", hf_overrides_kw) + hf_config.update(hf_overrides_kw) + if hf_overrides_fn: + logger.info("Overriding HF config with %s", hf_overrides_fn) + hf_config = hf_overrides_fn(hf_config) + + self.hf_config = hf_config + + self.hf_text_config = get_hf_text_config(self.hf_config) + self.attention_chunk_size = getattr(self.hf_text_config, + "attention_chunk_size", None) + self.encoder_config = self._get_encoder_config() + self.hf_image_processor_config = get_hf_image_processor_config( + self.model, hf_token=hf_token, revision=revision) + self.dtype = _get_and_verify_dtype(self.hf_config, dtype) + self.use_async_output_proc = use_async_output_proc + self.mm_processor_kwargs = mm_processor_kwargs + self.disable_mm_preprocessor_cache = disable_mm_preprocessor_cache + + # Set enforce_eager to False if the value is unset. + if self.enforce_eager is None: + self.enforce_eager = False + + interleaved_attn_models = ["gemma2", "gemma3_text", "cohere2"] + sliding_window = getattr(self.hf_text_config, "sliding_window", None) + has_interleaved_attention = (sliding_window is not None) and ( + isinstance(sliding_window, list) or + (self.hf_text_config.model_type in interleaved_attn_models)) + + if (not self.disable_sliding_window and has_interleaved_attention): + if (backend := + envs.VLLM_ATTENTION_BACKEND) in ("XFORMERS", "FLASHINFER"): + sliding_window_len_min = get_min_sliding_window( + self.hf_text_config.sliding_window) + + logger.warning_once( + f"{self.hf_text_config.model_type} has interleaved " + "attention, which is currently not supported by the " + f"{backend} backend. Disabling sliding window and capping " + "the max length to the sliding window size " + f"({sliding_window_len_min}).") + self.disable_sliding_window = True + else: + # for a model with interleaved attention, + # the scheduler and the model treat it as full attention + # (i.e., not dropping any tokens outside the window). + # only the attention layer itself is aware of the sliding + # window, and use the window size to compute the attention. 
+ self.hf_text_config.interleaved_sliding_window = sliding_window + delattr(self.hf_text_config, "sliding_window") + sliding_window = None + + self.max_model_len = _get_and_verify_max_len( + hf_config=self.hf_text_config, + max_model_len=max_model_len, + disable_sliding_window=self.disable_sliding_window, + sliding_window_len=self.get_hf_config_sliding_window(), + spec_target_max_model_len=spec_target_max_model_len, + encoder_config=self.encoder_config) + self.served_model_name = get_served_model_name(model, served_model_name) + self.multimodal_config = self._init_multimodal_config(limit_mm_per_prompt) + if not self.skip_tokenizer_init: + self._verify_tokenizer_mode() + + self.is_attention_free = self._init_attention_free() + self.is_hybrid = self._init_is_hybrid() + self.has_noops = self._init_has_noops() + self.has_inner_state = self._init_has_inner_state() + + if current_platform.is_neuron(): + self.override_neuron_config = override_neuron_config + else: + self.override_neuron_config = None + + supported_tasks, task = self._resolve_task(task) + self.supported_tasks = supported_tasks + self.task: Final = task # type: ignore + if self.task in ("draft", "generate"): + self.truncation_side = "left" + else: + self.truncation_side = "right" + + self.pooler_config = self._init_pooler_config(override_pooler_config) + self.logits_processor_pattern = logits_processor_pattern + + self.generation_config = generation_config + self.override_generation_config = override_generation_config or {} + + self._verify_quantization() + self._verify_cuda_graph() + self._verify_bnb_config() + + +# The platform assertion is deleted to support the npu platform. +ModelConfig.__init__ = new_init diff --git a/vllm_ascend/platform.py b/vllm_ascend/platform.py index 5a09905d7..fad885e63 100644 --- a/vllm_ascend/platform.py +++ b/vllm_ascend/platform.py @@ -63,10 +63,17 @@ class NPUPlatform(Platform): supported_quantization: list[str] = ["ascend"] + def is_sleep_mode_available(self) -> bool: + return True + @classmethod def pre_register_and_update(cls, parser: Optional[FlexibleArgumentParser] = None ) -> None: + # Adapt the global patch here. 
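+        # (moved here from vllm_ascend.__init__.register(), which no longer
+        # applies the global patch)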
+ from vllm_ascend.utils import adapt_patch + adapt_patch(is_global_patch=True) + from vllm_ascend.quantization.quant_config import \ AscendQuantConfig # noqa: F401 diff --git a/vllm_ascend/worker/worker.py b/vllm_ascend/worker/worker.py index 595fb4652..f1bf0b843 100644 --- a/vllm_ascend/worker/worker.py +++ b/vllm_ascend/worker/worker.py @@ -36,13 +36,14 @@ from vllm.model_executor.model_loader.tensorizer import TensorizerConfig from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sequence import (ExecuteModelRequest, IntermediateTensors, SequenceGroupMetadata, SequenceGroupMetadataDelta) -from vllm.utils import bind_kv_cache +from vllm.utils import GiB_bytes, bind_kv_cache from vllm.worker.cache_engine import CacheEngine from vllm.worker.enc_dec_model_runner import EncoderDecoderModelRunner from vllm.worker.model_runner_base import ModelRunnerBase from vllm.worker.worker_base import (LocalOrDistributedWorkerBase, WorkerBase, WorkerInput) +from vllm_ascend.device_allocator.camem import CaMemAllocator from vllm_ascend.platform import NPUPlatform from vllm_ascend.utils import try_register_lib, vllm_version_is from vllm_ascend.worker.model_runner import NPUModelRunner @@ -159,6 +160,24 @@ class NPUWorker(LocalOrDistributedWorkerBase): else: self.profiler = None + def sleep(self, level: int = 1) -> None: + NPUPlatform.set_device(self.device) + free_bytes_before_sleep = NPUPlatform.mem_get_info()[0] + allocator = CaMemAllocator.get_instance() + allocator.sleep(offload_tags=("weights", ) if level == 1 else tuple()) + free_bytes_after_sleep, total = NPUPlatform.mem_get_info() + freed_bytes = free_bytes_after_sleep - free_bytes_before_sleep + used_bytes = total - free_bytes_after_sleep + assert freed_bytes >= 0, "Memory usage increased after sleeping." 
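+        # With level 1 the "weights" pool has been offloaded to host memory and
+        # the "kv_cache" pool discarded; with level 2 both pools are discarded.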
+ logger.info( + "Sleep mode freed %.2f GiB memory, " + "%.2f GiB memory is still in use.", freed_bytes / GiB_bytes, + used_bytes / GiB_bytes) + + def wake_up(self, tags: Optional[list[str]] = None) -> None: + allocator = CaMemAllocator.get_instance() + allocator.wake_up(tags=tags) + def init_device(self) -> None: if self.device_config.device.type == "npu": self.device = torch.device(f"npu:{self.local_rank}") @@ -176,7 +195,17 @@ class NPUWorker(LocalOrDistributedWorkerBase): set_random_seed(self.model_config.seed) def load_model(self): - self.model_runner.load_model() + if self.vllm_config.model_config.enable_sleep_mode: + allocator = CaMemAllocator.get_instance() + assert allocator.get_current_usage() == 0, ( + "Sleep mode can only be " + "used for one instance per process.") + context = allocator.use_memory_pool(tag="weights") + else: + from contextlib import nullcontext + context = nullcontext() # type: ignore + with context: + self.model_runner.load_model() def start_profile(self): if self.profiler is None: @@ -263,8 +292,14 @@ class NPUWorker(LocalOrDistributedWorkerBase): self.cache_config.num_gpu_blocks = num_gpu_blocks self.cache_config.num_cpu_blocks = num_cpu_blocks - - self._init_cache_engine() + if self.vllm_config.model_config.enable_sleep_mode: + allocator = CaMemAllocator.get_instance() + context = allocator.use_memory_pool(tag="kv_cache") + else: + from contextlib import nullcontext + context = nullcontext() # type: ignore + with context: + self._init_cache_engine() self._warm_up_model() def _init_cache_engine(self): diff --git a/vllm_ascend/worker/worker_v1.py b/vllm_ascend/worker/worker_v1.py index f839518b1..0cdb00a80 100644 --- a/vllm_ascend/worker/worker_v1.py +++ b/vllm_ascend/worker/worker_v1.py @@ -89,6 +89,12 @@ class NPUWorker(WorkerBase): self.profiler = self._init_profiler() + def sleep(self, level: int = 1) -> None: + logger.error("Sleep mode is only supported on v0") + + def wake_up(self, tags: Optional[list[str]] = None) -> None: + logger.error("Sleep mode is only supported on v0") + def init_device(self): if self.device_config.device.type == "npu": self.device = torch.device(f"npu:{self.local_rank}")