Mirror of https://github.com/vllm-project/vllm.git (synced 2025-10-20 14:53:52 +08:00)
[CPU] Enable shared-memory based pipeline parallel for CPU backend (#21289)
Signed-off-by: jiang1.li <jiang1.li@intel.com>
@@ -6,6 +6,7 @@ set -ex

# allow to bind to different cores
CORE_RANGE=${CORE_RANGE:-48-95}
# used for TP/PP E2E test
OMP_CORE_RANGE=${OMP_CORE_RANGE:-48-95}
NUMA_NODE=${NUMA_NODE:-1}

@@ -24,8 +25,8 @@ numactl -C "$CORE_RANGE" -N "$NUMA_NODE" docker build --tag cpu-test-"$NUMA_NODE
numactl -C "$CORE_RANGE" -N "$NUMA_NODE" docker build --build-arg VLLM_CPU_DISABLE_AVX512="true" --tag cpu-test-"$NUMA_NODE"-avx2 --target vllm-test -f docker/Dockerfile.cpu .

# Run the image, setting --shm-size=4g for tensor parallel.
docker run -itd --cpuset-cpus="$CORE_RANGE" --cpuset-mems="$NUMA_NODE" --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface --privileged=true -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --env VLLM_CPU_CI_ENV=1 --shm-size=4g --name cpu-test-"$NUMA_NODE" cpu-test-"$NUMA_NODE"
docker run -itd --cpuset-cpus="$CORE_RANGE" --cpuset-mems="$NUMA_NODE" --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface --privileged=true -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --env VLLM_CPU_CI_ENV=1 --shm-size=4g --name cpu-test-"$NUMA_NODE"-avx2 cpu-test-"$NUMA_NODE"-avx2
docker run -itd --cpuset-cpus="$CORE_RANGE" --cpuset-mems="$NUMA_NODE" --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface --privileged=true -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --env VLLM_CPU_CI_ENV=1 -e E2E_OMP_THREADS="$OMP_CORE_RANGE" --shm-size=4g --name cpu-test-"$NUMA_NODE" cpu-test-"$NUMA_NODE"
docker run -itd --cpuset-cpus="$CORE_RANGE" --cpuset-mems="$NUMA_NODE" --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface --privileged=true -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --env VLLM_CPU_CI_ENV=1 -e E2E_OMP_THREADS="$OMP_CORE_RANGE" --shm-size=4g --name cpu-test-"$NUMA_NODE"-avx2 cpu-test-"$NUMA_NODE"-avx2

function cpu_tests() {
set -e
@@ -78,17 +79,16 @@ function cpu_tests() {
# tests/quantization/test_ipex_quant.py"

# online serving
docker exec cpu-test-"$NUMA_NODE" bash -c "
docker exec cpu-test-"$NUMA_NODE" bash -c '
set -e
python3 -m vllm.entrypoints.openai.api_server --model facebook/opt-125m --dtype half &
timeout 600 bash -c 'until curl localhost:8000/v1/models; do sleep 1; done' || exit 1
VLLM_CPU_CI_ENV=0 python3 benchmarks/benchmark_serving.py \
VLLM_CPU_OMP_THREADS_BIND=$E2E_OMP_THREADS VLLM_CPU_SGL_KERNEL=1 vllm serve meta-llama/Llama-3.2-3B-Instruct -tp=2 -pp=2 &
timeout 600 bash -c "until curl localhost:8000/v1/models; do sleep 1; done" || exit 1
python3 benchmarks/benchmark_serving.py \
--backend vllm \
--dataset-name random \
--model facebook/opt-125m \
--model meta-llama/Llama-3.2-3B-Instruct \
--num-prompts 20 \
--endpoint /v1/completions \
--tokenizer facebook/opt-125m"
--endpoint /v1/completions'

# Run multi-lora tests
docker exec cpu-test-"$NUMA_NODE" bash -c "
@@ -7,7 +7,7 @@

namespace {
#define MAX_SHM_RANK_NUM 8
#define PER_THREAD_SHM_BUFFER_BYTES (2 * 1024 * 1024)
#define PER_THREAD_SHM_BUFFER_BYTES (4 * 1024 * 1024)
static_assert(PER_THREAD_SHM_BUFFER_BYTES % 2 == 0);
#define PER_THREAD_SHM_BUFFER_OFFSET (PER_THREAD_SHM_BUFFER_BYTES >> 1)
#define MIN_THREAD_PROCESS_SIZE (256)
@@ -34,9 +34,10 @@ struct KernelVecType<c10::Half> {
};

struct ThreadSHMContext {
volatile char _curr_thread_stamp;
volatile char _ready_thread_stamp;
char _padding1[6];
volatile char _curr_thread_stamp[2];
volatile char _ready_thread_stamp[2];
int local_stamp_buffer_idx;
int remote_stamp_buffer_idx;
int thread_id;
int thread_num;
int rank;
@@ -45,23 +46,28 @@ struct ThreadSHMContext {
int swizzled_ranks[MAX_SHM_RANK_NUM];
void* thread_shm_ptrs[MAX_SHM_RANK_NUM];
ThreadSHMContext* shm_contexts[MAX_SHM_RANK_NUM];
size_t _thread_buffer_mask;
char _padding2[56];
size_t _thread_buffer_mask[2];
char _padding2[40];

ThreadSHMContext(const int thread_id, const int thread_num, const int rank,
const int group_size, void* thread_shm_ptr)
: _curr_thread_stamp(1),
_ready_thread_stamp(0),
: local_stamp_buffer_idx(0),
remote_stamp_buffer_idx(0),
thread_id(thread_id),
thread_num(thread_num),
rank(rank),
group_size(group_size),
_spinning_count(0),
_thread_buffer_mask(0) {
_spinning_count(0) {
static_assert(sizeof(ThreadSHMContext) % 64 == 0);
TORCH_CHECK(group_size <= MAX_SHM_RANK_NUM);
TORCH_CHECK((size_t)this % 64 == 0);
TORCH_CHECK((size_t)thread_shm_ptr % 64 == 0);
_curr_thread_stamp[0] = 1;
_curr_thread_stamp[1] = 1;
_ready_thread_stamp[0] = 0;
_ready_thread_stamp[1] = 0;
_thread_buffer_mask[0] = 0;
_thread_buffer_mask[1] = 0;
for (int i = 0; i < MAX_SHM_RANK_NUM; ++i) {
shm_contexts[i] = nullptr;
thread_shm_ptrs[i] = nullptr;
@@ -70,6 +76,11 @@ struct ThreadSHMContext {
set_context(rank, this, thread_shm_ptr);
}

void set_stamp_buffer_idx(int local, int remote) {
local_stamp_buffer_idx = local;
remote_stamp_buffer_idx = remote;
}

void set_context(int rank, ThreadSHMContext* ptr, void* thread_shm_ptr) {
TORCH_CHECK(rank < MAX_SHM_RANK_NUM);
TORCH_CHECK(ptr);
@@ -84,23 +95,27 @@ struct ThreadSHMContext {
T* get_thread_shm_ptr(int rank) {
return reinterpret_cast<T*>(
reinterpret_cast<int8_t*>(thread_shm_ptrs[rank]) +
(PER_THREAD_SHM_BUFFER_OFFSET & _thread_buffer_mask));
(PER_THREAD_SHM_BUFFER_OFFSET &
_thread_buffer_mask[local_stamp_buffer_idx]));
}

void next_buffer() { _thread_buffer_mask ^= 0xFFFFFFFFFFFFFFFF; }
void next_buffer() {
_thread_buffer_mask[local_stamp_buffer_idx] ^= 0xFFFFFFFFFFFFFFFF;
}

char get_curr_stamp() const { return _curr_thread_stamp; }
char get_curr_stamp(int idx) const { return _curr_thread_stamp[idx]; }

char get_ready_stamp() const { return _ready_thread_stamp; }
char get_ready_stamp(int idx) const { return _ready_thread_stamp[idx]; }

void next_stamp() {
_mm_mfence();
_curr_thread_stamp += 1;
_curr_thread_stamp[local_stamp_buffer_idx] += 1;
}

void commit_ready_stamp() {
_mm_mfence();
_ready_thread_stamp = _curr_thread_stamp;
_ready_thread_stamp[local_stamp_buffer_idx] =
_curr_thread_stamp[local_stamp_buffer_idx];
}

int get_swizzled_rank(int idx) { return swizzled_ranks[idx]; }
@@ -117,10 +132,11 @@ struct ThreadSHMContext {
void wait_for_one(int rank, Cond&& cond) {
ThreadSHMContext* rank_ctx = shm_contexts[rank];
for (;;) {
char local_curr_stamp = get_curr_stamp();
char local_ready_stamp = get_ready_stamp();
char rank_curr_stamp = rank_ctx->get_curr_stamp();
char rank_ready_stamp = rank_ctx->get_ready_stamp();
char local_curr_stamp = get_curr_stamp(local_stamp_buffer_idx);
char local_ready_stamp = get_ready_stamp(local_stamp_buffer_idx);
char rank_curr_stamp = rank_ctx->get_curr_stamp(remote_stamp_buffer_idx);
char rank_ready_stamp =
rank_ctx->get_ready_stamp(remote_stamp_buffer_idx);
if (cond(local_curr_stamp, local_ready_stamp, rank_curr_stamp,
rank_ready_stamp)) {
break;
@@ -361,6 +377,15 @@ void shm_cc_loop(ThreadSHMContext* ctx, int64_t elem_num, F&& inner_func) {
}
}
}

void reset_threads_stamp_buffer_idx(ThreadSHMContext* ctx, int local,
int remote) {
int thread_num = ctx->thread_num;
for (int i = 0; i < thread_num; ++i) {
ThreadSHMContext* thread_ctx = ctx + i;
thread_ctx->set_stamp_buffer_idx(local, remote);
}
}
}; // namespace shm_cc_ops

namespace shm_cc_ops {
@@ -632,6 +657,7 @@ void shm_send_tensor_list_impl(ThreadSHMContext* ctx, int64_t dst,
TensorListMeta* metadata = new (metadata_tensor.data_ptr()) TensorListMeta();
metadata->bind_tensor_list(tensor_list_with_metadata);

shm_cc_ops::reset_threads_stamp_buffer_idx(ctx, 0, 1);
shm_cc_ops::shm_cc_loop<int8_t>(
ctx, metadata->total_bytes,
[&](ThreadSHMContext* thread_ctx, int64_t data_offset,
@@ -659,6 +685,7 @@ std::vector<torch::Tensor> shm_recv_tensor_list_impl(ThreadSHMContext* ctx,
torch::Tensor metadata_tensor =
torch::empty({sizeof(TensorListMeta)}, options);

shm_cc_ops::reset_threads_stamp_buffer_idx(ctx, 1, 0);
ctx->wait_for_one(src, ThreadSHMContext::check_stamp_ready);
shm_cc_ops::memcpy(metadata_tensor.data_ptr(),
ctx->get_thread_shm_ptr<void>(src),
@@ -677,7 +704,7 @@ std::vector<torch::Tensor> shm_recv_tensor_list_impl(ThreadSHMContext* ctx,
ctx, metadata.total_bytes,
[&](ThreadSHMContext* thread_ctx, int64_t data_offset,
int64_t data_elem_num, bool fast_mode) {
ctx->wait_for_one(src, ThreadSHMContext::check_stamp_ready);
thread_ctx->wait_for_one(src, ThreadSHMContext::check_stamp_ready);
int64_t curr_shm_offset = 0;
while (curr_shm_offset < data_elem_num) {
MemPiece frag = metadata.get_data(data_offset + curr_shm_offset);
@@ -166,6 +166,20 @@ Note, it is recommended to manually reserve 1 CPU for vLLM front-end process whe

- This value is 4GB by default. A larger space can support more concurrent requests and longer context lengths. However, users should take care of the memory capacity of each NUMA node. The memory usage of each TP rank is the sum of `weight shard size` and `VLLM_CPU_KVCACHE_SPACE`; if it exceeds the capacity of a single NUMA node, the TP worker will be killed with `exitcode 9` due to out-of-memory.
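
To make the sizing rule above concrete, here is a rough back-of-the-envelope sketch; all numbers such as model size and NUMA-node capacity are assumptions for illustration, not part of this change:

```python
# Rough per-rank memory check for the CPU backend (illustrative numbers only).
GiB = 1024**3

model_weights_gib = 16        # e.g. an ~8B-parameter model in bf16 (assumed)
tensor_parallel_size = 2      # weights are sharded across TP ranks
kvcache_space_gib = 4         # VLLM_CPU_KVCACHE_SPACE, 4 GiB by default
numa_node_capacity_gib = 64   # memory attached to one NUMA node (assumed)

per_rank_gib = model_weights_gib / tensor_parallel_size + kvcache_space_gib
print(f"estimated per-rank usage: {per_rank_gib:.1f} GiB")
if per_rank_gib > numa_node_capacity_gib:
    print("likely to be killed with exitcode 9 (OOM) on this NUMA node")
```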

### How to do performance tuning for vLLM CPU?

- First of all, please make sure the thread-binding and KV cache space are properly set and take effect. You can check the thread-binding by running a vLLM benchmark and observing CPU core usage via `htop`.

- Inference batch size is an important parameter for performance. A larger batch usually provides higher throughput, while a smaller batch provides lower latency. Tuning the max batch size, starting from the default value, to balance throughput and latency is an effective way to improve vLLM CPU performance on a specific platform. There are two important related parameters in vLLM (see the sketch after this list):
    - `--max-num-batched-tokens`, defines the limit of token numbers in a single batch, has more impact on the first-token performance. The default value is set as:
        - Offline Inference: `4096 * world_size`
        - Online Serving: `2048 * world_size`
    - `--max-num-seqs`, defines the limit of sequence numbers in a single batch, has more impact on the output-token performance.
        - Offline Inference: `256 * world_size`
        - Online Serving: `128 * world_size`
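
A minimal offline-inference sketch that overrides both knobs explicitly; the model name and values are placeholders rather than recommendations:

```python
from vllm import LLM, SamplingParams

# Start from the CPU defaults and adjust to trade throughput against latency.
llm = LLM(
    model="meta-llama/Llama-3.2-3B-Instruct",
    max_num_batched_tokens=8192,  # cap on tokens per batch; mainly affects first-token latency
    max_num_seqs=256,             # cap on sequences per batch; mainly affects output throughput
)
outputs = llm.generate(["Summarize pipeline parallelism in one sentence."],
                       SamplingParams(max_tokens=64))
print(outputs[0].outputs[0].text)
```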

- vLLM CPU supports tensor parallel (TP) and pipeline parallel (PP) to leverage multiple CPU sockets and memory nodes. For more details on tuning TP and PP, please refer to [Optimization and Tuning](../../configuration/optimization.md). For vLLM CPU, it is recommended to use TP and PP together if there are enough CPU sockets and memory nodes; a minimal example follows.
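
The CI script above exercises this topology with `vllm serve meta-llama/Llama-3.2-3B-Instruct -tp=2 -pp=2`. A comparable offline sketch, assuming a host with enough sockets and NUMA memory for the 4 resulting ranks:

```python
from vllm import LLM

# 2-way tensor parallel x 2-way pipeline parallel on the CPU backend.
llm = LLM(
    model="meta-llama/Llama-3.2-3B-Instruct",
    tensor_parallel_size=2,
    pipeline_parallel_size=2,  # CPU PP relies on the shared-memory path added by this change
)
print(llm.generate("The capital of France is")[0].outputs[0].text)
```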

### Which quantization configs does vLLM CPU support?

- vLLM CPU supports quantizations:
@@ -2,11 +2,12 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import os
from typing import Optional
from typing import Any, Optional, Union

import torch
from torch.distributed import ProcessGroup

from vllm.distributed.utils import pickle
from vllm.platforms import current_platform
from vllm.platforms.interface import CpuArchEnum

@@ -26,7 +27,8 @@ class CpuCommunicator(DeviceCommunicatorBase):
if (current_platform.get_cpu_architecture()
== CpuArchEnum.X86) and hasattr(
torch.ops._C,
"init_shm_manager") and unique_name.startswith("tp"):
"init_shm_manager") and (unique_name.startswith("tp")
or unique_name.startswith("pp")):
self.dist_module = _CPUSHMDistributed(self)

def all_reduce(self, input_):
@@ -94,6 +96,19 @@ class CpuCommunicator(DeviceCommunicatorBase):
input_size[dim + 1:])
return output_tensor

def send_tensor_dict(
self,
tensor_dict: dict[str, Union[torch.Tensor, Any]],
dst: int,
) -> None:
return self.dist_module.send_tensor_dict(tensor_dict, dst)

def recv_tensor_dict(
self,
src: int,
) -> dict[str, Union[torch.Tensor, Any]]:
return self.dist_module.recv_tensor_dict(src)


class _CPUSHMDistributed:

@@ -143,3 +158,44 @@ class _CPUSHMDistributed:
input: torch.Tensor,
group: Optional[ProcessGroup] = None) -> None:
torch.ops._C.shm_all_gather(self.handle, input, output)

def send_tensor_dict(
self,
tensor_dict: dict[str, Union[torch.Tensor, Any]],
dst: int,
) -> None:
key_list = list(tensor_dict.keys())
value_list = list(tensor_dict.values())
size_list = []
for v in value_list:
if not isinstance(v, torch.Tensor):
raise RuntimeError(
"CpuCommunicator only supports sending tensors.")
size_list.append(v.size())
key_size_tensor = torch.frombuffer(pickle.dumps([key_list, size_list]),
dtype=torch.uint8)
value_list.append(key_size_tensor)

torch.ops._C.shm_send_tensor_list(self.handle, value_list, dst)

return None

def recv_tensor_dict(
self,
src: int,
) -> dict[str, Union[torch.Tensor, Any]]:
tensor_list = torch.ops._C.shm_recv_tensor_list(self.handle, src)

value_list: list[torch.Tensor] = tensor_list[:-1]
key_size_tensor = tensor_list[-1]

key_size = pickle.loads(key_size_tensor.numpy().tobytes())
key_list = key_size[0]
size_list = key_size[1]
assert len(key_list) == len(size_list)
assert len(key_list) == len(value_list)

tensor_dict: dict[str, torch.Tensor] = {}
for key, size, t in zip(key_list, size_list, value_list):
tensor_dict[key] = t.view(size)
return tensor_dict
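
For orientation, the packing scheme used by `send_tensor_dict`/`recv_tensor_dict` above (dict keys and tensor shapes pickled into a trailing uint8 tensor) can be round-tripped in a standalone sketch; this is an illustration of the idea, not vLLM's API:

```python
import pickle
import torch

# A dict of tensors to transport, as a pipeline-parallel stage would produce.
tensor_dict = {"hidden_states": torch.randn(2, 8), "residual": torch.randn(2, 8)}

# Sender side: flatten into a tensor list plus one pickled metadata tensor.
keys, values = list(tensor_dict.keys()), list(tensor_dict.values())
sizes = [v.size() for v in values]
meta = torch.frombuffer(bytearray(pickle.dumps([keys, sizes])), dtype=torch.uint8)
payload = values + [meta]  # what shm_send_tensor_list would move through shared memory

# Receiver side: split the payload and rebuild the dict.
recv_values, recv_meta = payload[:-1], payload[-1]
recv_keys, recv_sizes = pickle.loads(recv_meta.numpy().tobytes())
rebuilt = {k: t.view(s) for k, s, t in zip(recv_keys, recv_sizes, recv_values)}
assert all(torch.equal(rebuilt[k], tensor_dict[k]) for k in tensor_dict)
```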
@@ -272,6 +272,9 @@ class GroupCoordinator:
self.use_custom_op_call = (current_platform.is_cuda_alike()
or current_platform.is_tpu())

self.use_cpu_custom_send_recv = (current_platform.is_cpu() and hasattr(
torch.ops._C, "init_shm_manager"))

@property
def first_rank(self):
"""Return the global rank of the first process in the group"""
@@ -663,6 +666,11 @@ class GroupCoordinator:
dst = (self.rank_in_group + 1) % self.world_size
assert dst < self.world_size, f"Invalid dst rank ({dst})"

if self.use_cpu_custom_send_recv:
self.device_communicator.send_tensor_dict( # type: ignore
tensor_dict, dst)
return None

metadata_list: list[tuple[Any, Any]] = []
assert isinstance(
tensor_dict,
@@ -718,6 +726,10 @@ class GroupCoordinator:
src = (self.rank_in_group - 1) % self.world_size
assert src < self.world_size, f"Invalid src rank ({src})"

if self.use_cpu_custom_send_recv:
return self.device_communicator.recv_tensor_dict( # type: ignore
src)

recv_metadata_list = self.recv_object(src=src)
tensor_dict: dict[str, Any] = {}
for key, value in recv_metadata_list:
@@ -1639,13 +1639,14 @@ class EngineArgs:

# cpu specific default values.
if current_platform.is_cpu():
world_size = self.pipeline_parallel_size * self.tensor_parallel_size
default_max_num_batched_tokens = {
UsageContext.LLM_CLASS: 4096,
UsageContext.OPENAI_API_SERVER: 2048,
UsageContext.LLM_CLASS: 4096 * world_size,
UsageContext.OPENAI_API_SERVER: 2048 * world_size,
}
default_max_num_seqs = {
UsageContext.LLM_CLASS: 128,
UsageContext.OPENAI_API_SERVER: 32,
UsageContext.LLM_CLASS: 256 * world_size,
UsageContext.OPENAI_API_SERVER: 128 * world_size,
}

use_context_value = usage_context.value if usage_context else None
@@ -42,7 +42,7 @@ if TYPE_CHECKING:
VLLM_USE_FLASHINFER_SAMPLER: Optional[bool] = None
VLLM_FLASHINFER_FORCE_TENSOR_CORES: bool = False
VLLM_PP_LAYER_PARTITION: Optional[str] = None
VLLM_CPU_KVCACHE_SPACE: int = 0
VLLM_CPU_KVCACHE_SPACE: Optional[int] = 0
VLLM_CPU_OMP_THREADS_BIND: str = ""
VLLM_CPU_NUM_OF_RESERVED_CPU: Optional[int] = None
VLLM_CPU_MOE_PREPACK: bool = True
@@ -430,9 +430,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
lambda: os.getenv("VLLM_PP_LAYER_PARTITION", None),

# (CPU backend only) CPU key-value cache space.
# default is 4 GiB
# default is None and will be set as 4 GB
"VLLM_CPU_KVCACHE_SPACE":
lambda: int(os.getenv("VLLM_CPU_KVCACHE_SPACE", "0")),
lambda: int(os.getenv("VLLM_CPU_KVCACHE_SPACE", "0"))
if "VLLM_CPU_KVCACHE_SPACE" in os.environ else None,

# (CPU backend only) CPU core ids bound by OpenMP threads, e.g., "0-31",
# "0,1,2", "0-31,33". CPU cores of different ranks are separated by '|'.
@@ -104,8 +104,19 @@ class CpuPlatform(Platform):

@classmethod
def get_device_total_memory(cls, device_id: int = 0) -> int:
import psutil
return psutil.virtual_memory().total
import vllm.envs as envs
from vllm.utils import GiB_bytes

kv_cache_space = envs.VLLM_CPU_KVCACHE_SPACE
if kv_cache_space is None:
kv_cache_space = 4 * GiB_bytes # type: ignore
logger.warning_once(
"Environment variable VLLM_CPU_KVCACHE_SPACE (GiB) "
"for CPU backend is not set, using 4 by default.")
else:
kv_cache_space *= GiB_bytes

return kv_cache_space

@classmethod
def set_device(cls, device: torch.device) -> None:
@@ -124,8 +135,6 @@ class CpuPlatform(Platform):

@classmethod
def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
import vllm.envs as envs
from vllm.utils import GiB_bytes
model_config = vllm_config.model_config

if model_config is not None:
@@ -162,20 +171,8 @@ class CpuPlatform(Platform):
" support fp16 for now, cast to bf16.")
model_config.dtype = torch.bfloat16

kv_cache_space = envs.VLLM_CPU_KVCACHE_SPACE

if kv_cache_space >= 0:
if kv_cache_space == 0:
cache_config.cpu_kvcache_space_bytes = 4 * GiB_bytes # type: ignore
logger.warning(
"Environment variable VLLM_CPU_KVCACHE_SPACE (GiB) "
"for CPU backend is not set, using 4 by default.")
else:
cache_config.cpu_kvcache_space_bytes = kv_cache_space * GiB_bytes # type: ignore # noqa
else:
raise RuntimeError(
"Invalid environment variable VLLM_CPU_KVCACHE_SPACE"
f" {kv_cache_space}, expect a positive integer value.")
cache_config.cpu_kvcache_space_bytes = \
CpuPlatform.get_device_total_memory()

parallel_config = vllm_config.parallel_config
if (parallel_config.world_size > 1
@@ -216,8 +213,6 @@ class CpuPlatform(Platform):
False,
"nan_asserts":
False,
"memory_planning":
True,
"epilogue_fusion":
True,
})