[CPU] Enable shared-memory based pipeline parallel for CPU backend (#21289)

Signed-off-by: jiang1.li <jiang1.li@intel.com>
Author: Li, Jiang
Date: 2025-07-22 00:07:08 +08:00
Committed by: GitHub
parent 6dda13c86b
commit a15a50fc17
8 changed files with 165 additions and 59 deletions

View File

@@ -6,6 +6,7 @@ set -ex
# allow to bind to different cores
CORE_RANGE=${CORE_RANGE:-48-95}
# used for TP/PP E2E test
OMP_CORE_RANGE=${OMP_CORE_RANGE:-48-95}
NUMA_NODE=${NUMA_NODE:-1}
@@ -24,8 +25,8 @@ numactl -C "$CORE_RANGE" -N "$NUMA_NODE" docker build --tag cpu-test-"$NUMA_NODE
numactl -C "$CORE_RANGE" -N "$NUMA_NODE" docker build --build-arg VLLM_CPU_DISABLE_AVX512="true" --tag cpu-test-"$NUMA_NODE"-avx2 --target vllm-test -f docker/Dockerfile.cpu .
# Run the image, setting --shm-size=4g for tensor parallel.
docker run -itd --cpuset-cpus="$CORE_RANGE" --cpuset-mems="$NUMA_NODE" --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface --privileged=true -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --env VLLM_CPU_CI_ENV=1 --shm-size=4g --name cpu-test-"$NUMA_NODE" cpu-test-"$NUMA_NODE"
docker run -itd --cpuset-cpus="$CORE_RANGE" --cpuset-mems="$NUMA_NODE" --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface --privileged=true -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --env VLLM_CPU_CI_ENV=1 --shm-size=4g --name cpu-test-"$NUMA_NODE"-avx2 cpu-test-"$NUMA_NODE"-avx2
docker run -itd --cpuset-cpus="$CORE_RANGE" --cpuset-mems="$NUMA_NODE" --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface --privileged=true -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --env VLLM_CPU_CI_ENV=1 -e E2E_OMP_THREADS="$OMP_CORE_RANGE" --shm-size=4g --name cpu-test-"$NUMA_NODE" cpu-test-"$NUMA_NODE"
docker run -itd --cpuset-cpus="$CORE_RANGE" --cpuset-mems="$NUMA_NODE" --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface --privileged=true -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --env VLLM_CPU_CI_ENV=1 -e E2E_OMP_THREADS="$OMP_CORE_RANGE" --shm-size=4g --name cpu-test-"$NUMA_NODE"-avx2 cpu-test-"$NUMA_NODE"-avx2
function cpu_tests() {
set -e
@@ -78,17 +79,16 @@ function cpu_tests() {
# tests/quantization/test_ipex_quant.py"
# online serving
docker exec cpu-test-"$NUMA_NODE" bash -c "
docker exec cpu-test-"$NUMA_NODE" bash -c '
set -e
python3 -m vllm.entrypoints.openai.api_server --model facebook/opt-125m --dtype half &
timeout 600 bash -c 'until curl localhost:8000/v1/models; do sleep 1; done' || exit 1
VLLM_CPU_CI_ENV=0 python3 benchmarks/benchmark_serving.py \
VLLM_CPU_OMP_THREADS_BIND=$E2E_OMP_THREADS VLLM_CPU_SGL_KERNEL=1 vllm serve meta-llama/Llama-3.2-3B-Instruct -tp=2 -pp=2 &
timeout 600 bash -c "until curl localhost:8000/v1/models; do sleep 1; done" || exit 1
python3 benchmarks/benchmark_serving.py \
--backend vllm \
--dataset-name random \
--model facebook/opt-125m \
--model meta-llama/Llama-3.2-3B-Instruct \
--num-prompts 20 \
--endpoint /v1/completions \
--tokenizer facebook/opt-125m"
--endpoint /v1/completions'
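The updated serving test above launches the model with `-tp=2 -pp=2` on the CPU backend and exercises it through the OpenAI-compatible endpoint via `benchmarks/benchmark_serving.py`. As a rough companion, a minimal smoke test against the same endpoint could look like the sketch below (not part of the CI script; host, port, and prompt are illustrative and assume the defaults used above):

```python
# Hypothetical smoke test against the server started by
# `vllm serve meta-llama/Llama-3.2-3B-Instruct -tp=2 -pp=2`;
# assumes the default localhost:8000 endpoint used in the CI script above.
import requests

resp = requests.post(
    "http://localhost:8000/v1/completions",
    json={
        "model": "meta-llama/Llama-3.2-3B-Instruct",
        "prompt": "Hello, my name is",
        "max_tokens": 16,
    },
    timeout=60,
)
resp.raise_for_status()
print(resp.json()["choices"][0]["text"])
```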
# Run multi-lora tests
docker exec cpu-test-"$NUMA_NODE" bash -c "

View File

@@ -7,7 +7,7 @@
namespace {
#define MAX_SHM_RANK_NUM 8
#define PER_THREAD_SHM_BUFFER_BYTES (2 * 1024 * 1024)
#define PER_THREAD_SHM_BUFFER_BYTES (4 * 1024 * 1024)
static_assert(PER_THREAD_SHM_BUFFER_BYTES % 2 == 0);
#define PER_THREAD_SHM_BUFFER_OFFSET (PER_THREAD_SHM_BUFFER_BYTES >> 1)
#define MIN_THREAD_PROCESS_SIZE (256)
@@ -34,9 +34,10 @@ struct KernelVecType<c10::Half> {
};
struct ThreadSHMContext {
volatile char _curr_thread_stamp;
volatile char _ready_thread_stamp;
char _padding1[6];
volatile char _curr_thread_stamp[2];
volatile char _ready_thread_stamp[2];
int local_stamp_buffer_idx;
int remote_stamp_buffer_idx;
int thread_id;
int thread_num;
int rank;
@@ -45,23 +46,28 @@ struct ThreadSHMContext {
int swizzled_ranks[MAX_SHM_RANK_NUM];
void* thread_shm_ptrs[MAX_SHM_RANK_NUM];
ThreadSHMContext* shm_contexts[MAX_SHM_RANK_NUM];
size_t _thread_buffer_mask;
char _padding2[56];
size_t _thread_buffer_mask[2];
char _padding2[40];
ThreadSHMContext(const int thread_id, const int thread_num, const int rank,
const int group_size, void* thread_shm_ptr)
: _curr_thread_stamp(1),
_ready_thread_stamp(0),
: local_stamp_buffer_idx(0),
remote_stamp_buffer_idx(0),
thread_id(thread_id),
thread_num(thread_num),
rank(rank),
group_size(group_size),
_spinning_count(0),
_thread_buffer_mask(0) {
_spinning_count(0) {
static_assert(sizeof(ThreadSHMContext) % 64 == 0);
TORCH_CHECK(group_size <= MAX_SHM_RANK_NUM);
TORCH_CHECK((size_t)this % 64 == 0);
TORCH_CHECK((size_t)thread_shm_ptr % 64 == 0);
_curr_thread_stamp[0] = 1;
_curr_thread_stamp[1] = 1;
_ready_thread_stamp[0] = 0;
_ready_thread_stamp[1] = 0;
_thread_buffer_mask[0] = 0;
_thread_buffer_mask[1] = 0;
for (int i = 0; i < MAX_SHM_RANK_NUM; ++i) {
shm_contexts[i] = nullptr;
thread_shm_ptrs[i] = nullptr;
@@ -70,6 +76,11 @@ struct ThreadSHMContext {
set_context(rank, this, thread_shm_ptr);
}
void set_stamp_buffer_idx(int local, int remote) {
local_stamp_buffer_idx = local;
remote_stamp_buffer_idx = remote;
}
void set_context(int rank, ThreadSHMContext* ptr, void* thread_shm_ptr) {
TORCH_CHECK(rank < MAX_SHM_RANK_NUM);
TORCH_CHECK(ptr);
@@ -84,23 +95,27 @@ struct ThreadSHMContext {
T* get_thread_shm_ptr(int rank) {
return reinterpret_cast<T*>(
reinterpret_cast<int8_t*>(thread_shm_ptrs[rank]) +
(PER_THREAD_SHM_BUFFER_OFFSET & _thread_buffer_mask));
(PER_THREAD_SHM_BUFFER_OFFSET &
_thread_buffer_mask[local_stamp_buffer_idx]));
}
void next_buffer() { _thread_buffer_mask ^= 0xFFFFFFFFFFFFFFFF; }
void next_buffer() {
_thread_buffer_mask[local_stamp_buffer_idx] ^= 0xFFFFFFFFFFFFFFFF;
}
char get_curr_stamp() const { return _curr_thread_stamp; }
char get_curr_stamp(int idx) const { return _curr_thread_stamp[idx]; }
char get_ready_stamp() const { return _ready_thread_stamp; }
char get_ready_stamp(int idx) const { return _ready_thread_stamp[idx]; }
void next_stamp() {
_mm_mfence();
_curr_thread_stamp += 1;
_curr_thread_stamp[local_stamp_buffer_idx] += 1;
}
void commit_ready_stamp() {
_mm_mfence();
_ready_thread_stamp = _curr_thread_stamp;
_ready_thread_stamp[local_stamp_buffer_idx] =
_curr_thread_stamp[local_stamp_buffer_idx];
}
int get_swizzled_rank(int idx) { return swizzled_ranks[idx]; }
@@ -117,10 +132,11 @@ struct ThreadSHMContext {
void wait_for_one(int rank, Cond&& cond) {
ThreadSHMContext* rank_ctx = shm_contexts[rank];
for (;;) {
char local_curr_stamp = get_curr_stamp();
char local_ready_stamp = get_ready_stamp();
char rank_curr_stamp = rank_ctx->get_curr_stamp();
char rank_ready_stamp = rank_ctx->get_ready_stamp();
char local_curr_stamp = get_curr_stamp(local_stamp_buffer_idx);
char local_ready_stamp = get_ready_stamp(local_stamp_buffer_idx);
char rank_curr_stamp = rank_ctx->get_curr_stamp(remote_stamp_buffer_idx);
char rank_ready_stamp =
rank_ctx->get_ready_stamp(remote_stamp_buffer_idx);
if (cond(local_curr_stamp, local_ready_stamp, rank_curr_stamp,
rank_ready_stamp)) {
break;
@@ -361,6 +377,15 @@ void shm_cc_loop(ThreadSHMContext* ctx, int64_t elem_num, F&& inner_func) {
}
}
}
void reset_threads_stamp_buffer_idx(ThreadSHMContext* ctx, int local,
int remote) {
int thread_num = ctx->thread_num;
for (int i = 0; i < thread_num; ++i) {
ThreadSHMContext* thread_ctx = ctx + i;
thread_ctx->set_stamp_buffer_idx(local, remote);
}
}
}; // namespace shm_cc_ops
namespace shm_cc_ops {
@@ -632,6 +657,7 @@ void shm_send_tensor_list_impl(ThreadSHMContext* ctx, int64_t dst,
TensorListMeta* metadata = new (metadata_tensor.data_ptr()) TensorListMeta();
metadata->bind_tensor_list(tensor_list_with_metadata);
shm_cc_ops::reset_threads_stamp_buffer_idx(ctx, 0, 1);
shm_cc_ops::shm_cc_loop<int8_t>(
ctx, metadata->total_bytes,
[&](ThreadSHMContext* thread_ctx, int64_t data_offset,
@@ -659,6 +685,7 @@ std::vector<torch::Tensor> shm_recv_tensor_list_impl(ThreadSHMContext* ctx,
torch::Tensor metadata_tensor =
torch::empty({sizeof(TensorListMeta)}, options);
shm_cc_ops::reset_threads_stamp_buffer_idx(ctx, 1, 0);
ctx->wait_for_one(src, ThreadSHMContext::check_stamp_ready);
shm_cc_ops::memcpy(metadata_tensor.data_ptr(),
ctx->get_thread_shm_ptr<void>(src),
@@ -677,7 +704,7 @@ std::vector<torch::Tensor> shm_recv_tensor_list_impl(ThreadSHMContext* ctx,
ctx, metadata.total_bytes,
[&](ThreadSHMContext* thread_ctx, int64_t data_offset,
int64_t data_elem_num, bool fast_mode) {
ctx->wait_for_one(src, ThreadSHMContext::check_stamp_ready);
thread_ctx->wait_for_one(src, ThreadSHMContext::check_stamp_ready);
int64_t curr_shm_offset = 0;
while (curr_shm_offset < data_elem_num) {
MemPiece frag = metadata.get_data(data_offset + curr_shm_offset);
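Conceptually, the change above double-buffers the synchronization stamps: each `ThreadSHMContext` now keeps one curr/ready stamp pair per direction (slot 0 for the send path, slot 1 for the receive path, selected via `reset_threads_stamp_buffer_idx`), so a rank that sends to its PP successor while also receiving from its predecessor does not clobber the other direction's handshake state. The toy, single-process Python model below is a conceptual sketch of that stamp handshake only; the real, multi-process logic is the C++ above:

```python
# Toy model of the stamp handshake: the writer commits a "ready" stamp after
# filling its buffer, the reader spins until the remote ready stamp catches up.
# Each direction (0 = send path, 1 = recv path) has its own stamp slot.
class ToyThreadSHMContext:
    def __init__(self):
        self.curr = [1, 1]            # models _curr_thread_stamp[2]
        self.ready = [0, 0]           # models _ready_thread_stamp[2]
        self.buffer = [None, None]    # stands in for the per-thread SHM buffer

    def send(self, idx, payload):
        self.buffer[idx] = payload          # write into "shared memory"
        self.ready[idx] = self.curr[idx]    # commit_ready_stamp()
        self.curr[idx] += 1                 # next_stamp()

    def recv_from(self, remote, idx, expected):
        while remote.ready[idx] < expected:  # wait_for_one(check_stamp_ready)
            pass                             # real code spins with a CPU pause
        return remote.buffer[idx]

sender, receiver = ToyThreadSHMContext(), ToyThreadSHMContext()
sender.send(0, b"tensor bytes")                   # PP sender uses slot 0
print(receiver.recv_from(sender, 0, expected=1))  # receiver reads sender's slot 0
```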

View File

@@ -166,6 +166,20 @@ Note, it is recommended to manually reserve 1 CPU for vLLM front-end process whe
- This value is 4 GiB by default. A larger space can support more concurrent requests and longer context lengths; however, users should take care of the memory capacity of each NUMA node. The memory usage of each TP rank is the sum of the `weight shard size` and `VLLM_CPU_KVCACHE_SPACE`; if it exceeds the capacity of a single NUMA node, the TP worker will be killed with `exitcode 9` due to out-of-memory.
### How to do performance tuning for vLLM CPU?
- First of all, please make sure the thread binding and KV cache space are properly set and take effect. You can check the thread binding by running a vLLM benchmark and observing CPU core usage via `htop`.
- Inference batch size is an important parameter for performance. A larger batch usually provides higher throughput, while a smaller batch provides lower latency. Tuning the maximum batch size, starting from the default value, to balance throughput and latency is an effective way to improve vLLM CPU performance on a specific platform. There are two important related parameters in vLLM:
  - `--max-num-batched-tokens` defines the limit on the number of tokens in a single batch and mainly impacts first-token performance. The default value is set as:
    - Offline Inference: `4096 * world_size`
    - Online Serving: `2048 * world_size`
  - `--max-num-seqs` defines the limit on the number of sequences in a single batch and mainly impacts output-token performance. The default value is set as:
    - Offline Inference: `256 * world_size`
    - Online Serving: `128 * world_size`
- vLLM CPU supports tensor parallel (TP) and pipeline parallel (PP) to leverage multiple CPU sockets and memory nodes. For more details on tuning TP and PP, please refer to [Optimization and Tuning](../../configuration/optimization.md). For vLLM CPU, it is recommended to use TP and PP together if there are enough CPU sockets and memory nodes; see the sketch after this list.
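A minimal offline-inference sketch showing how these knobs could be combined on a multi-socket CPU host (the model name and values are illustrative, not recommendations):

```python
# Illustrative only: 2-way TP x 2-way PP (world_size = 4) with explicit batch
# limits; with the new CPU defaults, offline inference would otherwise use
# max_num_batched_tokens = 4096 * 4 and max_num_seqs = 256 * 4.
from vllm import LLM, SamplingParams

llm = LLM(
    model="meta-llama/Llama-3.2-3B-Instruct",
    tensor_parallel_size=2,
    pipeline_parallel_size=2,
    max_num_batched_tokens=16384,  # caps tokens per batch; impacts first-token latency
    max_num_seqs=512,              # caps sequences per batch; impacts output-token throughput
)
outputs = llm.generate(["The capital of France is"], SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)
```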
### Which quantization configs does vLLM CPU support?
- vLLM CPU supports quantizations:

View File

@@ -2,11 +2,12 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
from typing import Optional
from typing import Any, Optional, Union
import torch
from torch.distributed import ProcessGroup
from vllm.distributed.utils import pickle
from vllm.platforms import current_platform
from vllm.platforms.interface import CpuArchEnum
@@ -26,7 +27,8 @@ class CpuCommunicator(DeviceCommunicatorBase):
if (current_platform.get_cpu_architecture()
== CpuArchEnum.X86) and hasattr(
torch.ops._C,
"init_shm_manager") and unique_name.startswith("tp"):
"init_shm_manager") and (unique_name.startswith("tp")
or unique_name.startswith("pp")):
self.dist_module = _CPUSHMDistributed(self)
def all_reduce(self, input_):
@@ -94,6 +96,19 @@ class CpuCommunicator(DeviceCommunicatorBase):
input_size[dim + 1:])
return output_tensor
def send_tensor_dict(
self,
tensor_dict: dict[str, Union[torch.Tensor, Any]],
dst: int,
) -> None:
return self.dist_module.send_tensor_dict(tensor_dict, dst)
def recv_tensor_dict(
self,
src: int,
) -> dict[str, Union[torch.Tensor, Any]]:
return self.dist_module.recv_tensor_dict(src)
class _CPUSHMDistributed:
@@ -143,3 +158,44 @@ class _CPUSHMDistributed:
input: torch.Tensor,
group: Optional[ProcessGroup] = None) -> None:
torch.ops._C.shm_all_gather(self.handle, input, output)
def send_tensor_dict(
self,
tensor_dict: dict[str, Union[torch.Tensor, Any]],
dst: int,
) -> None:
key_list = list(tensor_dict.keys())
value_list = list(tensor_dict.values())
size_list = []
for v in value_list:
if not isinstance(v, torch.Tensor):
raise RuntimeError(
"CpuCommunicator only supports sending tensors.")
size_list.append(v.size())
key_size_tensor = torch.frombuffer(pickle.dumps([key_list, size_list]),
dtype=torch.uint8)
value_list.append(key_size_tensor)
torch.ops._C.shm_send_tensor_list(self.handle, value_list, dst)
return None
def recv_tensor_dict(
self,
src: int,
) -> dict[str, Union[torch.Tensor, Any]]:
tensor_list = torch.ops._C.shm_recv_tensor_list(self.handle, src)
value_list: list[torch.Tensor] = tensor_list[:-1]
key_size_tensor = tensor_list[-1]
key_size = pickle.loads(key_size_tensor.numpy().tobytes())
key_list = key_size[0]
size_list = key_size[1]
assert len(key_list) == len(size_list)
assert len(key_list) == len(value_list)
tensor_dict: dict[str, torch.Tensor] = {}
for key, size, t in zip(key_list, size_list, value_list):
tensor_dict[key] = t.view(size)
return tensor_dict
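The key/size metadata trick used by `send_tensor_dict`/`recv_tensor_dict` above can be illustrated in isolation: the dict keys and tensor shapes are pickled into a `uint8` tensor that travels as the last element of the tensor list, and the receiver unpickles it to rebuild the dict. A self-contained sketch (no shared memory involved; the "transport" is faked by reusing the same list):

```python
# Sketch of the key/size metadata round trip; the real path sends `wire`
# through torch.ops._C.shm_send_tensor_list / shm_recv_tensor_list.
import pickle
import torch

tensor_dict = {"hidden_states": torch.randn(2, 8), "residual": torch.randn(2, 8)}

# Sender side: pickle [keys, sizes] into a uint8 tensor appended to the list.
keys, sizes = list(tensor_dict), [t.size() for t in tensor_dict.values()]
meta = torch.frombuffer(pickle.dumps([keys, sizes]), dtype=torch.uint8)
wire = [t.contiguous() for t in tensor_dict.values()] + [meta]

# Receiver side: unpickle the trailing tensor and rebuild the dict.
recv_keys, recv_sizes = pickle.loads(wire[-1].numpy().tobytes())
rebuilt = {k: t.view(s) for k, s, t in zip(recv_keys, recv_sizes, wire[:-1])}
assert rebuilt["hidden_states"].shape == (2, 8)
```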

View File

@@ -272,6 +272,9 @@ class GroupCoordinator:
self.use_custom_op_call = (current_platform.is_cuda_alike()
or current_platform.is_tpu())
self.use_cpu_custom_send_recv = (current_platform.is_cpu() and hasattr(
torch.ops._C, "init_shm_manager"))
@property
def first_rank(self):
"""Return the global rank of the first process in the group"""
@@ -663,6 +666,11 @@ class GroupCoordinator:
dst = (self.rank_in_group + 1) % self.world_size
assert dst < self.world_size, f"Invalid dst rank ({dst})"
if self.use_cpu_custom_send_recv:
self.device_communicator.send_tensor_dict( # type: ignore
tensor_dict, dst)
return None
metadata_list: list[tuple[Any, Any]] = []
assert isinstance(
tensor_dict,
@@ -718,6 +726,10 @@ class GroupCoordinator:
src = (self.rank_in_group - 1) % self.world_size
assert src < self.world_size, f"Invalid src rank ({src})"
if self.use_cpu_custom_send_recv:
return self.device_communicator.recv_tensor_dict( # type: ignore
src)
recv_metadata_list = self.recv_object(src=src)
tensor_dict: dict[str, Any] = {}
for key, value in recv_metadata_list:
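For context, this is roughly how the new fast path is reached from a pipeline-parallel worker. The sketch below assumes the distributed groups have already been initialized; `hidden_states` and `residual` are placeholder tensors, not names taken from this diff:

```python
# Sketch: intermediate tensors are handed to the next PP rank via
# GroupCoordinator.send_tensor_dict(); on CPU with the SHM manager available,
# this now dispatches to CpuCommunicator.send_tensor_dict() instead of the
# generic metadata-broadcast path.
import torch
from vllm.distributed.parallel_state import get_pp_group

pp_group = get_pp_group()
hidden_states = torch.randn(8, 4096)   # placeholder intermediate activations
residual = torch.randn(8, 4096)

if not pp_group.is_last_rank:
    pp_group.send_tensor_dict({"hidden_states": hidden_states,
                               "residual": residual})
if not pp_group.is_first_rank:
    received = pp_group.recv_tensor_dict()
```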

View File

@@ -1639,13 +1639,14 @@ class EngineArgs:
# cpu specific default values.
if current_platform.is_cpu():
world_size = self.pipeline_parallel_size * self.tensor_parallel_size
default_max_num_batched_tokens = {
UsageContext.LLM_CLASS: 4096,
UsageContext.OPENAI_API_SERVER: 2048,
UsageContext.LLM_CLASS: 4096 * world_size,
UsageContext.OPENAI_API_SERVER: 2048 * world_size,
}
default_max_num_seqs = {
UsageContext.LLM_CLASS: 128,
UsageContext.OPENAI_API_SERVER: 32,
UsageContext.LLM_CLASS: 256 * world_size,
UsageContext.OPENAI_API_SERVER: 128 * world_size,
}
use_context_value = usage_context.value if usage_context else None
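A worked example of the rescaled CPU defaults above, assuming the `-tp=2 -pp=2` layout used in the CI test and the online-serving (`OPENAI_API_SERVER`) context:

```python
# world_size = pipeline_parallel_size * tensor_parallel_size = 2 * 2 = 4
world_size = 2 * 2
max_num_batched_tokens = 2048 * world_size   # 8192 (previously fixed at 2048)
max_num_seqs = 128 * world_size              # 512  (previously fixed at 32)
print(max_num_batched_tokens, max_num_seqs)
```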

View File

@@ -42,7 +42,7 @@ if TYPE_CHECKING:
VLLM_USE_FLASHINFER_SAMPLER: Optional[bool] = None
VLLM_FLASHINFER_FORCE_TENSOR_CORES: bool = False
VLLM_PP_LAYER_PARTITION: Optional[str] = None
VLLM_CPU_KVCACHE_SPACE: int = 0
VLLM_CPU_KVCACHE_SPACE: Optional[int] = 0
VLLM_CPU_OMP_THREADS_BIND: str = ""
VLLM_CPU_NUM_OF_RESERVED_CPU: Optional[int] = None
VLLM_CPU_MOE_PREPACK: bool = True
@@ -430,9 +430,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
lambda: os.getenv("VLLM_PP_LAYER_PARTITION", None),
# (CPU backend only) CPU key-value cache space.
# default is 4 GiB
# default is None and will be set to 4 GiB
"VLLM_CPU_KVCACHE_SPACE":
lambda: int(os.getenv("VLLM_CPU_KVCACHE_SPACE", "0")),
lambda: int(os.getenv("VLLM_CPU_KVCACHE_SPACE", "0"))
if "VLLM_CPU_KVCACHE_SPACE" in os.environ else None,
# (CPU backend only) CPU core ids bound by OpenMP threads, e.g., "0-31",
# "0,1,2", "0-31,33". CPU cores of different ranks are separated by '|'.

View File

@@ -104,8 +104,19 @@ class CpuPlatform(Platform):
@classmethod
def get_device_total_memory(cls, device_id: int = 0) -> int:
import psutil
return psutil.virtual_memory().total
import vllm.envs as envs
from vllm.utils import GiB_bytes
kv_cache_space = envs.VLLM_CPU_KVCACHE_SPACE
if kv_cache_space is None:
kv_cache_space = 4 * GiB_bytes # type: ignore
logger.warning_once(
"Environment variable VLLM_CPU_KVCACHE_SPACE (GiB) "
"for CPU backend is not set, using 4 by default.")
else:
kv_cache_space *= GiB_bytes
return kv_cache_space
@classmethod
def set_device(cls, device: torch.device) -> None:
@@ -124,8 +135,6 @@ class CpuPlatform(Platform):
@classmethod
def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
import vllm.envs as envs
from vllm.utils import GiB_bytes
model_config = vllm_config.model_config
if model_config is not None:
@@ -162,20 +171,8 @@
" support fp16 for now, cast to bf16.")
model_config.dtype = torch.bfloat16
kv_cache_space = envs.VLLM_CPU_KVCACHE_SPACE
if kv_cache_space >= 0:
if kv_cache_space == 0:
cache_config.cpu_kvcache_space_bytes = 4 * GiB_bytes # type: ignore
logger.warning(
"Environment variable VLLM_CPU_KVCACHE_SPACE (GiB) "
"for CPU backend is not set, using 4 by default.")
else:
cache_config.cpu_kvcache_space_bytes = kv_cache_space * GiB_bytes # type: ignore # noqa
else:
raise RuntimeError(
"Invalid environment variable VLLM_CPU_KVCACHE_SPACE"
f" {kv_cache_space}, expect a positive integer value.")
cache_config.cpu_kvcache_space_bytes = \
CpuPlatform.get_device_total_memory()
parallel_config = vllm_config.parallel_config
if (parallel_config.world_size > 1
@@ -216,8 +213,6 @@ class CpuPlatform(Platform):
False,
"nan_asserts":
False,
"memory_planning":
True,
"epilogue_fusion":
True,
})
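Putting the pieces above together, the KV-cache budget in bytes is now derived in one place (`get_device_total_memory`) and assigned to `cache_config.cpu_kvcache_space_bytes`. A hedged sketch of the resulting behavior (the helper name is hypothetical; values are illustrative):

```python
# Hypothetical helper mirroring the logic above: None (variable unset) falls
# back to 4 GiB with a one-time warning, an explicit value is interpreted as GiB.
GiB_bytes = 1 << 30

def cpu_kvcache_space_bytes(env_value):  # env_value mirrors envs.VLLM_CPU_KVCACHE_SPACE
    if env_value is None:
        return 4 * GiB_bytes             # default when VLLM_CPU_KVCACHE_SPACE is unset
    return env_value * GiB_bytes

assert cpu_kvcache_space_bytes(None) == 4 * GiB_bytes   # variable not set
assert cpu_kvcache_space_bytes(40) == 40 * GiB_bytes    # VLLM_CPU_KVCACHE_SPACE=40
```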