[bugfix] Fix auto thread-binding when world_size > 1 in CPU backend and refactor code (#21032)

Signed-off-by: jiang1.li <jiang1.li@intel.com>
Li, Jiang
2025-07-19 20:13:55 +08:00
committed by GitHub
parent b3d82108e7
commit e3a0e43d7f
7 changed files with 144 additions and 150 deletions

View File

@@ -24,8 +24,8 @@ numactl -C "$CORE_RANGE" -N "$NUMA_NODE" docker build --tag cpu-test-"$NUMA_NODE
numactl -C "$CORE_RANGE" -N "$NUMA_NODE" docker build --build-arg VLLM_CPU_DISABLE_AVX512="true" --tag cpu-test-"$NUMA_NODE"-avx2 --target vllm-test -f docker/Dockerfile.cpu .
# Run the image, setting --shm-size=4g for tensor parallelism.
docker run -itd --cpuset-cpus="$CORE_RANGE" --cpuset-mems="$NUMA_NODE" --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface --privileged=true -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --env VLLM_CPU_OMP_THREADS_BIND="$OMP_CORE_RANGE" --env VLLM_CPU_CI_ENV=1 --shm-size=4g --name cpu-test-"$NUMA_NODE" cpu-test-"$NUMA_NODE"
docker run -itd --cpuset-cpus="$CORE_RANGE" --cpuset-mems="$NUMA_NODE" --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface --privileged=true -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --env VLLM_CPU_OMP_THREADS_BIND="$OMP_CORE_RANGE" --env VLLM_CPU_CI_ENV=1 --shm-size=4g --name cpu-test-"$NUMA_NODE"-avx2 cpu-test-"$NUMA_NODE"-avx2
docker run -itd --cpuset-cpus="$CORE_RANGE" --cpuset-mems="$NUMA_NODE" --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface --privileged=true -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --env VLLM_CPU_CI_ENV=1 --shm-size=4g --name cpu-test-"$NUMA_NODE" cpu-test-"$NUMA_NODE"
docker run -itd --cpuset-cpus="$CORE_RANGE" --cpuset-mems="$NUMA_NODE" --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface --privileged=true -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --env VLLM_CPU_CI_ENV=1 --shm-size=4g --name cpu-test-"$NUMA_NODE"-avx2 cpu-test-"$NUMA_NODE"-avx2
function cpu_tests() {
  set -e

View File

@@ -94,8 +94,8 @@ Currently, there are no pre-built CPU wheels.
## Related runtime environment variables
- `VLLM_CPU_KVCACHE_SPACE`: specify the KV cache size (e.g., `VLLM_CPU_KVCACHE_SPACE=40` means 40 GiB of space for the KV cache); a larger setting allows vLLM to run more requests in parallel. This parameter should be set based on the hardware configuration and the memory management pattern of users. Default value is `0`.
- `VLLM_CPU_OMP_THREADS_BIND`: specify the CPU cores dedicated to the OpenMP threads. For example, `VLLM_CPU_OMP_THREADS_BIND=0-31` means there will be 32 OpenMP threads bound to CPU cores 0-31. `VLLM_CPU_OMP_THREADS_BIND=0-31|32-63` means there will be 2 tensor parallel processes: the 32 OpenMP threads of rank0 are bound to CPU cores 0-31, and the OpenMP threads of rank1 are bound to CPU cores 32-63. When set to `auto`, the OpenMP threads of each rank are bound to the CPU cores of one NUMA node each. When set to `all`, the OpenMP threads of each rank use all CPU cores available on the system. Default value is `auto`.
- `VLLM_CPU_NUM_OF_RESERVED_CPU`: specify the number of CPU cores that are not dedicated to the OpenMP threads of each rank. The variable only takes effect when `VLLM_CPU_OMP_THREADS_BIND` is set to `auto`. Default value is `0`.
- `VLLM_CPU_OMP_THREADS_BIND`: specify the CPU cores dedicated to the OpenMP threads; can be set to CPU id lists or `auto` (the default). For example, `VLLM_CPU_OMP_THREADS_BIND=0-31` means there will be 32 OpenMP threads bound to CPU cores 0-31. `VLLM_CPU_OMP_THREADS_BIND=0-31|32-63` means there will be 2 tensor parallel processes: the 32 OpenMP threads of rank0 are bound to CPU cores 0-31, and the OpenMP threads of rank1 are bound to CPU cores 32-63. When set to `auto`, the OpenMP threads of each rank are bound to the CPU cores of one NUMA node each.
- `VLLM_CPU_NUM_OF_RESERVED_CPU`: specify the number of CPU cores that are not dedicated to the OpenMP threads of each rank. The variable only takes effect when `VLLM_CPU_OMP_THREADS_BIND` is set to `auto`. Default value is `None`. If the value is not set and `auto` thread binding is used, no CPU is reserved when `world_size == 1`, and 1 CPU per rank is reserved when `world_size > 1`.
- `VLLM_CPU_MOE_PREPACK` (x86 only): whether to use prepack for MoE layer. This will be passed to `ipex.llm.modules.GatedMLPMOE`. Default is `1` (True). On unsupported CPUs, you might need to set this to `0` (False).
- `VLLM_CPU_SGL_KERNEL` (x86 only, Experimental): whether to use small-batch optimized kernels for linear layer and MoE layer, especially for low-latency requirements like online serving. The kernels require AMX instruction set, BFloat16 weight type and weight shapes divisible by 32. Default is `0` (False).
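For example, a minimal sketch of wiring these variables up from Python before engine construction (the core ranges, model, and sizes here are illustrative assumptions, not recommendations):

```python
import os

# Manual binding: rank0 on cores 0-31, rank1 on cores 32-63 (2 TP ranks).
os.environ["VLLM_CPU_OMP_THREADS_BIND"] = "0-31|32-63"
# Reserve 40 GiB for the KV cache.
os.environ["VLLM_CPU_KVCACHE_SPACE"] = "40"

from vllm import LLM

llm = LLM(model="facebook/opt-125m", dtype="bfloat16",
          tensor_parallel_size=2)
```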
@@ -123,9 +123,13 @@ export VLLM_CPU_NUM_OF_RESERVED_CPU=1
vllm serve facebook/opt-125m --dtype=bfloat16
```
Note: it is recommended to manually reserve 1 CPU for the vLLM front-end process when `world_size == 1`.
### How to decide `VLLM_CPU_OMP_THREADS_BIND`?
- Bind each OpenMP thread to a dedicated physical CPU core, or use the auto thread-binding feature by default. On a hyper-threading enabled platform with 16 logical CPU cores / 8 physical CPU cores:
- The default `auto` thread-binding is recommended for most cases. Ideally, each OpenMP thread is bound to a dedicated physical core, the threads of each rank are bound to the same NUMA node, and 1 CPU per rank is reserved for other vLLM components when `world_size > 1`. If you hit any performance problems or unexpected binding behavior, please try binding threads manually as follows.
- On a hyper-threading enabled platform with 16 logical CPU cores / 8 physical CPU cores:
??? console "Commands"
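The collapsed "Commands" block above is not reproduced here. As a stand-in sketch (not the original commands), the same topology information can be inspected the way the new CPU backend does, via `lscpu`:

```python
# Print (logical CPU, physical core, NUMA node) triples, mirroring the
# `lscpu -J -e=CPU,CORE,NODE` call used by the CPU platform code.
import json
import subprocess

out = subprocess.check_output("lscpu -J -e=CPU,CORE,NODE",
                              shell=True, text=True)
for cpu in json.loads(out)["cpus"]:
    print(cpu["cpu"], cpu["core"], cpu["node"])
```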

View File

@@ -24,6 +24,4 @@ datasets # for benchmark scripts
# Intel Extension for PyTorch, only for x86_64 CPUs
intel-openmp==2024.2.1; platform_machine == "x86_64"
intel_extension_for_pytorch==2.6.0; platform_machine == "x86_64" # torch>2.6.0+cpu has performance regression on x86 platform, see https://github.com/pytorch/pytorch/pull/151218
py-libnuma; platform_system != "Darwin"
psutil; platform_system != "Darwin"
triton==3.2.0; platform_machine == "x86_64" # Triton is required for torch 2.6+cpu, as it is imported in torch.compile.

View File

@@ -44,7 +44,7 @@ if TYPE_CHECKING:
    VLLM_PP_LAYER_PARTITION: Optional[str] = None
    VLLM_CPU_KVCACHE_SPACE: int = 0
    VLLM_CPU_OMP_THREADS_BIND: str = ""
    VLLM_CPU_NUM_OF_RESERVED_CPU: int = 0
    VLLM_CPU_NUM_OF_RESERVED_CPU: Optional[int] = None
    VLLM_CPU_MOE_PREPACK: bool = True
    VLLM_CPU_SGL_KERNEL: bool = False
    VLLM_XLA_CACHE_PATH: str = os.path.join(VLLM_CACHE_ROOT, "xla_cache")
@@ -442,7 +442,8 @@ environment_variables: dict[str, Callable[[], Any]] = {
    # (CPU backend only) CPU cores not used by OMP threads.
    # Those CPU cores will not be used by the OMP threads of a rank.
    "VLLM_CPU_NUM_OF_RESERVED_CPU":
    lambda: int(os.getenv("VLLM_CPU_NUM_OF_RESERVED_CPU", "0")),
    lambda: int(os.getenv("VLLM_CPU_NUM_OF_RESERVED_CPU", "0"))
    if "VLLM_CPU_NUM_OF_RESERVED_CPU" in os.environ else None,

    # (CPU backend only) whether to use prepack for MoE layer. This will be
    # passed to ipex.llm.modules.GatedMLPMOE. On unsupported CPUs, you might
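
# --- Illustrative sketch (not part of this diff): the changed default above
# makes the variable tri-state (unset -> None, set -> int), so callers can
# distinguish "not configured" from an explicit 0. Standalone demo:
import os
from typing import Optional

def reserved_cpus() -> Optional[int]:
    return (int(os.getenv("VLLM_CPU_NUM_OF_RESERVED_CPU", "0"))
            if "VLLM_CPU_NUM_OF_RESERVED_CPU" in os.environ else None)

assert reserved_cpus() is None  # nothing exported yet
os.environ["VLLM_CPU_NUM_OF_RESERVED_CPU"] = "2"
assert reserved_cpus() == 2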

View File

@@ -1,9 +1,12 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json
import os
import platform
import subprocess
import sys
from dataclasses import dataclass
from importlib.util import find_spec
from typing import TYPE_CHECKING, Optional
@@ -31,6 +34,35 @@ def get_max_threads(pid=0):
        raise NotImplementedError("Unsupported OS")


@dataclass
class LogicalCPUInfo:
    id: int = -1
    physical_core: int = -1
    numa_node: int = -1

    @classmethod
    def _int(cls, value: str) -> int:
        try:
            int_value = int(value)
        except Exception:
            int_value = -1
        return int_value

    @staticmethod
    def json_decoder(obj_dict: dict):
        id = obj_dict.get("cpu")
        physical_core = obj_dict.get("core")
        numa_node = obj_dict.get("node")
        if not (id is None or physical_core is None or numa_node is None):
            return LogicalCPUInfo(
                id=LogicalCPUInfo._int(id),
                physical_core=LogicalCPUInfo._int(physical_core),
                numa_node=LogicalCPUInfo._int(numa_node))
        else:
            return obj_dict
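
# --- Illustrative sketch (not part of this diff): how json_decoder plugs
# into json.loads as an object_hook. The sample JSON mimics the shape of
# `lscpu -J -e=CPU,CORE,NODE` output; the values are made up.
import json

sample = ('{"cpus": [{"cpu": "0", "core": "0", "node": "0"},'
          ' {"cpu": "1", "core": "0", "node": "0"}]}')
cpus = json.loads(sample, object_hook=LogicalCPUInfo.json_decoder)["cpus"]
assert cpus[0] == LogicalCPUInfo(id=0, physical_core=0, numa_node=0)
assert cpus[1].physical_core == 0  # SMT sibling on the same core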
class CpuPlatform(Platform):
    _enum = PlatformEnum.CPU
    device_name: str = "cpu"
@@ -240,6 +272,38 @@ class CpuPlatform(Platform):
                vllm_config.scheduler_config.max_model_len,
                DEFAULT_MAX_NUM_BATCHED_TOKENS)

    @classmethod
    def get_allowed_cpu_memory_node_list(
            cls) -> tuple[list[int], list[LogicalCPUInfo]]:
        assert platform.system() == "Linux"

        # Init LogicalCPUInfo from lscpu
        lscpu_output = subprocess.check_output("lscpu -J -e=CPU,CORE,NODE",
                                               shell=True,
                                               text=True)
        logical_cpu_list: list[LogicalCPUInfo] = json.loads(
            lscpu_output, object_hook=LogicalCPUInfo.json_decoder)['cpus']

        # Filter CPUs with invalid attributes
        logical_cpu_list = [
            x for x in logical_cpu_list
            if -1 not in (x.id, x.physical_core, x.numa_node)
        ]

        # Filter allowed CPUs
        allowed_cpu_id_list = os.sched_getaffinity(0)
        logical_cpu_list = [
            x for x in logical_cpu_list if x.id in allowed_cpu_id_list
        ]

        # Get allowed NUMA nodes
        allowed_numa_nodes = set()
        for x in logical_cpu_list:
            allowed_numa_nodes.add(x.numa_node)  # type: ignore
        allowed_numa_nodes_list = sorted(allowed_numa_nodes)

        return allowed_numa_nodes_list, logical_cpu_list
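
# --- Illustrative sketch (not part of this diff): using the helper above on
# a Linux host. The printed values are hypothetical.
allowed_nodes, cpus = CpuPlatform.get_allowed_cpu_memory_node_list()
print(allowed_nodes)  # e.g. [0, 1]
print([(c.id, c.physical_core, c.numa_node) for c in cpus[:2]])
# e.g. [(0, 0, 0), (1, 1, 0)]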
    @classmethod
    def is_pin_memory_available(cls) -> bool:
        logger.warning("Pin memory is not supported on CPU.")

View File

@@ -45,9 +45,10 @@ class CPUModelRunner(GPUModelRunner):
            if k.endswith("_cpu_tensor") and isinstance(v, torch.Tensor):
                replace_tensor(self.input_batch, k, k[:-11])

        for k, v in vars(self.input_batch.block_table).items():
            if k.endswith("_cpu") and isinstance(v, torch.Tensor):
                replace_tensor(self.input_batch.block_table, k, k[:-4])
        for block_table in self.input_batch.block_table.block_tables:
            for k, v in vars(block_table).items():
                if k.endswith("_cpu") and isinstance(v, torch.Tensor):
                    replace_tensor(block_table, k, k[:-4])
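
# --- Illustrative sketch (not part of this diff): an assumed, simplified
# equivalent of the replace_tensor helper called above. It rebinds a
# device-side attribute (e.g. "slot_mapping") to its CPU staging twin
# (e.g. "slot_mapping_cpu") so the CPU tensors are used directly.
def replace_tensor(obj, cpu_attr_name: str, attr_name: str) -> None:
    cpu_tensor = getattr(obj, cpu_attr_name)
    setattr(obj, attr_name, cpu_tensor)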
    def load_model(self, eep_scale_up: bool = False) -> None:
        logger.info("Starting to load model %s...", self.model_config.model)

View File

@@ -1,8 +1,8 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
from importlib import util
from typing import Optional
import platform
from typing import Callable, Optional
import torch
@@ -12,21 +12,14 @@ from vllm.distributed.parallel_state import get_pp_group, get_tp_group
from vllm.logger import init_logger
from vllm.model_executor.utils import set_random_seed
from vllm.platforms import CpuArchEnum, current_platform
from vllm.platforms.cpu import CpuPlatform, LogicalCPUInfo
from vllm.sequence import IntermediateTensors
from vllm.utils import PlaceholderModule
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.worker.cpu_model_runner import CPUModelRunner
from vllm.v1.worker.gpu_worker import (Worker,
                                       init_worker_distributed_environment)

try:
    import psutil
    from numa import info
except ImportError:
    psutil = PlaceholderModule("psutil")  # type: ignore[assignment]
    numa = PlaceholderModule("numa")  # type: ignore[assignment]
logger = init_logger(__name__)
@@ -45,20 +38,21 @@
                         is_driver_worker=is_driver_worker)

        self.parallel_config.disable_custom_all_reduce = True
        self.manually_bind_threads_suggestion = (
            "To get better performance, please try to manually bind threads.")

    def init_device(self):
        # Setup OpenMP threads affinity.
        omp_cpuids = envs.VLLM_CPU_OMP_THREADS_BIND
        self.local_omp_cpuid = "all"
        if omp_cpuids == "auto":
        if omp_cpuids == "auto" and platform.system() == "Linux":
            if current_platform.get_cpu_architecture() == CpuArchEnum.POWERPC:
                self.local_omp_cpuid = (
                    self.get_cpus_id_binding_based_on_numa_nodes_ppc64le())
                # For POWERPC SMT-8/4/2
                self.local_omp_cpuid = self._get_autobind_cpu_ids(
                    lambda cpus: [cpu for cpu in cpus if cpu.id % 8 < 4])
            elif current_platform.get_cpu_architecture() == CpuArchEnum.X86:
                # For x86 SMT-2, use 1 CPU per core
                self.local_omp_cpuid = self._get_autobind_cpu_ids(
                    lambda cpus: cpus[-1:])
            else:
                self.local_omp_cpuid = (
                    self.get_cpus_id_binding_based_on_numa_nodes())
                self.local_omp_cpuid = "all"
        else:
            self.local_omp_cpuid = omp_cpuids.split("|")[self.rank]
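
# --- Illustrative sketch (not part of this diff): what the two selectors
# above pick from one physical core's SMT siblings, sorted by id.
from vllm.platforms.cpu import LogicalCPUInfo

# x86 SMT-2 core exposing logical CPUs 3 and 67: keep the last sibling.
smt2 = [LogicalCPUInfo(id=3), LogicalCPUInfo(id=67)]
assert [c.id for c in smt2[-1:]] == [67]

# POWER SMT-8 core exposing logical CPUs 8..15: keep the first 4 threads.
smt8 = [LogicalCPUInfo(id=i) for i in range(8, 16)]
assert [c.id for c in smt8 if c.id % 8 < 4] == [8, 9, 10, 11]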
@@ -122,126 +116,58 @@
        assert isinstance(output, ModelRunnerOutput)
        return output if self.is_driver_worker else None

    def warn_inability_to_detect_numa(self) -> None:
        logger.warning(
            "Auto thread-binding failed due to the "
            "inability to detect numa nodes. %s",
            self.manually_bind_threads_suggestion)

    def warn_lack_of_numa_and_psutil(self) -> None:
        logger.warning(
            "Auto thread-binding failed due to "
            "the lack of package numa and psutil. %s",
            self.manually_bind_threads_suggestion)

    def warn_world_size_too_large(self, world_size: int,
                                  node_to_cpus_len: int) -> None:
        logger.warning(
            "Auto thread-binding failed due to "
            "world size: %d being larger than "
            "allowed NUMA nodes number: %d. %s", world_size, node_to_cpus_len,
            self.manually_bind_threads_suggestion)

    def get_cpus_allow_list_and_numa_size(self):
        cpus_allow_list = psutil.Process().cpu_affinity()
        numa_size = info.get_num_configured_nodes()
        return cpus_allow_list, numa_size
    def auto_thread_binding_based_on_numa_nodes(self, world_size: int,
                                                rank_to_cpus: str) -> str:
        cpu_count = psutil.cpu_count(logical=False)
        cpus_allow_list, numa_size = self.get_cpus_allow_list_and_numa_size()
        if not numa_size:
            self.warn_inability_to_detect_numa()
            return rank_to_cpus

        cpu_count_per_numa = cpu_count // numa_size
        num_of_reserved_cpu = min(envs.VLLM_CPU_NUM_OF_RESERVED_CPU,
                                  cpu_count_per_numa // 2)

        node_to_cpus = []
        for i in range(numa_size):
            node_intersect = set(
                info.node_to_cpus(i)).intersection(cpus_allow_list)
            if bool(node_intersect):
                node_to_cpus.append(list(node_intersect))

        node_to_cpus_len = len(node_to_cpus)
        if world_size > node_to_cpus_len:
            self.warn_world_size_too_large(world_size, node_to_cpus_len)
        else:
            end = cpu_count_per_numa - num_of_reserved_cpu
            rank_to_cpus_list = node_to_cpus[self.rank][:end]
            rank_to_cpus = ','.join(str(x) for x in rank_to_cpus_list)
            logger.info("auto thread-binding list: %s", rank_to_cpus)
        return rank_to_cpus
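
# --- Illustrative sketch (not part of this diff): the reservation clamp in
# the pre-change algorithm above, with made-up numbers: 32 physical CPUs
# across 2 NUMA nodes and VLLM_CPU_NUM_OF_RESERVED_CPU=20.
cpu_count_per_numa = 32 // 2                            # 16 CPUs per node
num_of_reserved_cpu = min(20, cpu_count_per_numa // 2)  # clamped to 8
end = cpu_count_per_numa - num_of_reserved_cpu          # rank keeps CPUs [:8]
assert (cpu_count_per_numa, num_of_reserved_cpu, end) == (16, 8, 8)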
    def libnuma_and_psutil_found(self) -> bool:
        libnuma_found = util.find_spec("numa") is not None
        psutil_found = util.find_spec("psutil") is not None
        return libnuma_found and psutil_found

    def get_cpus_id_binding_based_on_numa_nodes(self) -> str:
        """Return CPUs id binding based on NUMA nodes.
    def _get_autobind_cpu_ids(
            self, cpu_selector: Callable[[list[LogicalCPUInfo]],
                                         list[LogicalCPUInfo]]
    ) -> str:
        """
        rank_to_cpus = self.local_omp_cpuid
        # Setup OpenMP thread affinity based on NUMA nodes automatically
        world_size = self.vllm_config.parallel_config.world_size
        if self.libnuma_and_psutil_found():
            rank_to_cpus = self.auto_thread_binding_based_on_numa_nodes(
                world_size, rank_to_cpus)
        else:
            self.warn_lack_of_numa_and_psutil()
        return rank_to_cpus
    def select_threads_per_power_core(self,
                                      node_cpu_ids: list[int]) -> list[int]:
        return [cpu for cpu in node_cpu_ids if cpu % 8 < 4]

    def auto_thread_binding_based_on_numa_nodes_ppc64le(
            self, world_size: int, rank_to_cpus: str) -> str:
        cpus_allow_list, numa_size = self.get_cpus_allow_list_and_numa_size()
        if not numa_size:
            self.warn_inability_to_detect_numa()
            return rank_to_cpus

        node_to_cpus = []
        for i in range(numa_size):
            node_intersect = set(
                info.node_to_cpus(i)).intersection(cpus_allow_list)
            if bool(node_intersect):
                node_to_cpus.append(sorted(list(node_intersect)))

        node_to_cpus_len = len(node_to_cpus)
        if world_size > node_to_cpus_len:
            self.warn_world_size_too_large(world_size, node_to_cpus_len)
        else:
            node_cpus_this_rank = node_to_cpus[self.rank]
            node_cpus_this_rank = self.select_threads_per_power_core(
                node_cpus_this_rank)
            cpu_count_per_numa = len(node_cpus_this_rank)
            num_of_reserved_cpu = min(envs.VLLM_CPU_NUM_OF_RESERVED_CPU,
                                      cpu_count_per_numa // 2)
            end = cpu_count_per_numa - num_of_reserved_cpu
            rank_to_cpus_list = node_cpus_this_rank[:end]
            rank_to_cpus = ','.join(str(x) for x in rank_to_cpus_list)
            logger.info("ppc64le thread-binding list: %s", rank_to_cpus)
        return rank_to_cpus
    def get_cpus_id_binding_based_on_numa_nodes_ppc64le(self) -> str:
        """
        Power (ppc64le) specific: Selects a subset of threads per core for
        each NUMA node. This is robust to SMT mode (SMT-8, SMT-4, etc.)
        because the OS only exposes available threads. This maximizes
        performance by avoiding oversubscription of logical CPUs on Power.
        Return CPU ids to bind based on NUMA nodes.
        Currently, for rank N, only CPU ids on the N-th node in the available
        NUMA node list will be selected.

        Args:
            cpu_selector: a callable object to select CPUs from a CPU list
            of a physical core. The input is a LogicalCPUInfo list, sorted by
            the LogicalCPUInfo.id. A selected LogicalCPUInfo list should be
            returned.
        """
        rank_to_cpus = self.local_omp_cpuid
        world_size = self.vllm_config.parallel_config.world_size
        if self.libnuma_and_psutil_found():
            rank_to_cpus = self.auto_thread_binding_based_on_numa_nodes_ppc64le(
                world_size, rank_to_cpus)
        else:
            self.warn_lack_of_numa_and_psutil()
        return rank_to_cpus
        allowed_numa_nodes, logical_cpu_list = \
            CpuPlatform.get_allowed_cpu_memory_node_list()
        assert len(allowed_numa_nodes) >= self.parallel_config.world_size, (
            f"Not enough allowed NUMA nodes to bind threads of "
            f"{self.parallel_config.world_size} CPUWorkers. "
            f"Allowed NUMA nodes are {allowed_numa_nodes}. "
            "Please try to bind threads manually.")

        # Get CPUs on NUMA node `allowed_numa_nodes[local_rank]`
        selected_numa_node = allowed_numa_nodes[
            self.local_rank]  # type: ignore
        logical_cpu_list = [
            x for x in logical_cpu_list if x.numa_node == selected_numa_node
        ]

        # Select CPUs from each physical core via cpu_selector
        core_to_cpus: dict[int, list[LogicalCPUInfo]] = {}
        for cpu_info in logical_cpu_list:
            if cpu_info.physical_core not in core_to_cpus:
                core_to_cpus[cpu_info.physical_core] = []
            core_to_cpus[cpu_info.physical_core].append(cpu_info)
        logical_cpu_list = []
        for cpu_list in core_to_cpus.values():
            cpu_list = sorted(cpu_list, key=lambda x: x.id)
            logical_cpu_list.extend(cpu_selector(cpu_list))
        logical_cpu_list = sorted(logical_cpu_list, key=lambda x: x.id)

        # Reserve CPUs for other processes
        reserve_cpu_num = envs.VLLM_CPU_NUM_OF_RESERVED_CPU
        if reserve_cpu_num is None:
            reserve_cpu_num = 1 if self.parallel_config.world_size > 1 else 0
        assert len(logical_cpu_list) > reserve_cpu_num, (
            f"VLLM_CPU_NUM_OF_RESERVED_CPU ({reserve_cpu_num}) "
            f"should be less than {len(logical_cpu_list)}.")
        if reserve_cpu_num != 0:
            logical_cpu_list = logical_cpu_list[:-reserve_cpu_num]

        logger.info("auto thread-binding list (id, physical core): %s",
                    [(x.id, x.physical_core) for x in logical_cpu_list])
        return ",".join([str(x.id) for x in logical_cpu_list])