[bugfix] Fix auto thread-binding when world_size > 1 in CPU backend and refactor code (#21032)
Signed-off-by: jiang1.li <jiang1.li@intel.com>
This commit is contained in:
@ -24,8 +24,8 @@ numactl -C "$CORE_RANGE" -N "$NUMA_NODE" docker build --tag cpu-test-"$NUMA_NODE
|
||||
numactl -C "$CORE_RANGE" -N "$NUMA_NODE" docker build --build-arg VLLM_CPU_DISABLE_AVX512="true" --tag cpu-test-"$NUMA_NODE"-avx2 --target vllm-test -f docker/Dockerfile.cpu .
|
||||
|
||||
# Run the image, setting --shm-size=4g for tensor parallel.
|
||||
docker run -itd --cpuset-cpus="$CORE_RANGE" --cpuset-mems="$NUMA_NODE" --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface --privileged=true -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --env VLLM_CPU_OMP_THREADS_BIND="$OMP_CORE_RANGE" --env VLLM_CPU_CI_ENV=1 --shm-size=4g --name cpu-test-"$NUMA_NODE" cpu-test-"$NUMA_NODE"
|
||||
docker run -itd --cpuset-cpus="$CORE_RANGE" --cpuset-mems="$NUMA_NODE" --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface --privileged=true -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --env VLLM_CPU_OMP_THREADS_BIND="$OMP_CORE_RANGE" --env VLLM_CPU_CI_ENV=1 --shm-size=4g --name cpu-test-"$NUMA_NODE"-avx2 cpu-test-"$NUMA_NODE"-avx2
|
||||
docker run -itd --cpuset-cpus="$CORE_RANGE" --cpuset-mems="$NUMA_NODE" --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface --privileged=true -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --env VLLM_CPU_CI_ENV=1 --shm-size=4g --name cpu-test-"$NUMA_NODE" cpu-test-"$NUMA_NODE"
|
||||
docker run -itd --cpuset-cpus="$CORE_RANGE" --cpuset-mems="$NUMA_NODE" --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface --privileged=true -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --env VLLM_CPU_CI_ENV=1 --shm-size=4g --name cpu-test-"$NUMA_NODE"-avx2 cpu-test-"$NUMA_NODE"-avx2
|
||||
|
||||
function cpu_tests() {
|
||||
set -e
|
||||
|
||||
@ -94,8 +94,8 @@ Currently, there are no pre-built CPU wheels.
|
||||
## Related runtime environment variables
|
||||
|
||||
- `VLLM_CPU_KVCACHE_SPACE`: specify the KV Cache size (e.g, `VLLM_CPU_KVCACHE_SPACE=40` means 40 GiB space for KV cache), larger setting will allow vLLM running more requests in parallel. This parameter should be set based on the hardware configuration and memory management pattern of users. Default value is `0`.
|
||||
- `VLLM_CPU_OMP_THREADS_BIND`: specify the CPU cores dedicated to the OpenMP threads. For example, `VLLM_CPU_OMP_THREADS_BIND=0-31` means there will be 32 OpenMP threads bound on 0-31 CPU cores. `VLLM_CPU_OMP_THREADS_BIND=0-31|32-63` means there will be 2 tensor parallel processes, 32 OpenMP threads of rank0 are bound on 0-31 CPU cores, and the OpenMP threads of rank1 are bound on 32-63 CPU cores. By setting to `auto`, the OpenMP threads of each rank are bound to the CPU cores in each NUMA node. By setting to `all`, the OpenMP threads of each rank uses all CPU cores available on the system. Default value is `auto`.
|
||||
- `VLLM_CPU_NUM_OF_RESERVED_CPU`: specify the number of CPU cores which are not dedicated to the OpenMP threads for each rank. The variable only takes effect when VLLM_CPU_OMP_THREADS_BIND is set to `auto`. Default value is `0`.
|
||||
- `VLLM_CPU_OMP_THREADS_BIND`: specify the CPU cores dedicated to the OpenMP threads, can be set as CPU id lists or `auto` (by default). For example, `VLLM_CPU_OMP_THREADS_BIND=0-31` means there will be 32 OpenMP threads bound on 0-31 CPU cores. `VLLM_CPU_OMP_THREADS_BIND=0-31|32-63` means there will be 2 tensor parallel processes, 32 OpenMP threads of rank0 are bound on 0-31 CPU cores, and the OpenMP threads of rank1 are bound on 32-63 CPU cores. By setting to `auto`, the OpenMP threads of each rank are bound to the CPU cores in each NUMA node respectively.
|
||||
- `VLLM_CPU_NUM_OF_RESERVED_CPU`: specify the number of CPU cores which are not dedicated to the OpenMP threads for each rank. The variable only takes effect when VLLM_CPU_OMP_THREADS_BIND is set to `auto`. Default value is `None`. If the value is not set and use `auto` thread binding, no CPU will be reserved for `world_size == 1`, 1 CPU per rank will be reserved for `world_size > 1`.
|
||||
- `VLLM_CPU_MOE_PREPACK` (x86 only): whether to use prepack for MoE layer. This will be passed to `ipex.llm.modules.GatedMLPMOE`. Default is `1` (True). On unsupported CPUs, you might need to set this to `0` (False).
|
||||
- `VLLM_CPU_SGL_KERNEL` (x86 only, Experimental): whether to use small-batch optimized kernels for linear layer and MoE layer, especially for low-latency requirements like online serving. The kernels require AMX instruction set, BFloat16 weight type and weight shapes divisible by 32. Default is `0` (False).
|
||||
|
||||
@ -123,9 +123,13 @@ export VLLM_CPU_NUM_OF_RESERVED_CPU=1
|
||||
vllm serve facebook/opt-125m --dtype=bfloat16
|
||||
```
|
||||
|
||||
Note, it is recommended to manually reserve 1 CPU for vLLM front-end process when `world_size == 1`.
|
||||
|
||||
### How to decide `VLLM_CPU_OMP_THREADS_BIND`?
|
||||
|
||||
- Bind each OpenMP thread to a dedicated physical CPU core respectively, or use auto thread binding feature by default. On a hyper-threading enabled platform with 16 logical CPU cores / 8 physical CPU cores:
|
||||
- Default `auto` thread-binding is recommended for most cases. Ideally, each OpenMP thread will be bound to a dedicated physical core respectively, threads of each rank will be bound to a same NUMA node respectively, and 1 CPU per rank will be reserved for other vLLM components when `world_size > 1`. If have any performance problems or unexpected binding behaviours, please try to bind threads as following.
|
||||
|
||||
- On a hyper-threading enabled platform with 16 logical CPU cores / 8 physical CPU cores:
|
||||
|
||||
??? console "Commands"
|
||||
|
||||
|
||||
@ -24,6 +24,4 @@ datasets # for benchmark scripts
|
||||
# Intel Extension for PyTorch, only for x86_64 CPUs
|
||||
intel-openmp==2024.2.1; platform_machine == "x86_64"
|
||||
intel_extension_for_pytorch==2.6.0; platform_machine == "x86_64" # torch>2.6.0+cpu has performance regression on x86 platform, see https://github.com/pytorch/pytorch/pull/151218
|
||||
py-libnuma; platform_system != "Darwin"
|
||||
psutil; platform_system != "Darwin"
|
||||
triton==3.2.0; platform_machine == "x86_64" # Triton is required for torch 2.6+cpu, as it is imported in torch.compile.
|
||||
|
||||
@ -44,7 +44,7 @@ if TYPE_CHECKING:
|
||||
VLLM_PP_LAYER_PARTITION: Optional[str] = None
|
||||
VLLM_CPU_KVCACHE_SPACE: int = 0
|
||||
VLLM_CPU_OMP_THREADS_BIND: str = ""
|
||||
VLLM_CPU_NUM_OF_RESERVED_CPU: int = 0
|
||||
VLLM_CPU_NUM_OF_RESERVED_CPU: Optional[int] = None
|
||||
VLLM_CPU_MOE_PREPACK: bool = True
|
||||
VLLM_CPU_SGL_KERNEL: bool = False
|
||||
VLLM_XLA_CACHE_PATH: str = os.path.join(VLLM_CACHE_ROOT, "xla_cache")
|
||||
@ -442,7 +442,8 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
||||
# (CPU backend only) CPU cores not used by OMP threads .
|
||||
# Those CPU cores will not be used by OMP threads of a rank.
|
||||
"VLLM_CPU_NUM_OF_RESERVED_CPU":
|
||||
lambda: int(os.getenv("VLLM_CPU_NUM_OF_RESERVED_CPU", "0")),
|
||||
lambda: int(os.getenv("VLLM_CPU_NUM_OF_RESERVED_CPU", "0"))
|
||||
if "VLLM_CPU_NUM_OF_RESERVED_CPU" in os.environ else None,
|
||||
|
||||
# (CPU backend only) whether to use prepack for MoE layer. This will be
|
||||
# passed to ipex.llm.modules.GatedMLPMOE. On unsupported CPUs, you might
|
||||
|
||||
@ -1,9 +1,12 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import json
|
||||
import os
|
||||
import platform
|
||||
import subprocess
|
||||
import sys
|
||||
from dataclasses import dataclass
|
||||
from importlib.util import find_spec
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
@ -31,6 +34,35 @@ def get_max_threads(pid=0):
|
||||
raise NotImplementedError("Unsupported OS")
|
||||
|
||||
|
||||
@dataclass
|
||||
class LogicalCPUInfo:
|
||||
id: int = -1
|
||||
physical_core: int = -1
|
||||
numa_node: int = -1
|
||||
|
||||
@classmethod
|
||||
def _int(cls, value: str) -> int:
|
||||
try:
|
||||
int_value = int(value)
|
||||
except Exception:
|
||||
int_value = -1
|
||||
return int_value
|
||||
|
||||
@staticmethod
|
||||
def json_decoder(obj_dict: dict):
|
||||
id = obj_dict.get("cpu")
|
||||
physical_core = obj_dict.get("core")
|
||||
numa_node = obj_dict.get("node")
|
||||
|
||||
if not (id is None or physical_core is None or numa_node is None):
|
||||
return LogicalCPUInfo(
|
||||
id=LogicalCPUInfo._int(id),
|
||||
physical_core=LogicalCPUInfo._int(physical_core),
|
||||
numa_node=LogicalCPUInfo._int(numa_node))
|
||||
else:
|
||||
return obj_dict
|
||||
|
||||
|
||||
class CpuPlatform(Platform):
|
||||
_enum = PlatformEnum.CPU
|
||||
device_name: str = "cpu"
|
||||
@ -240,6 +272,38 @@ class CpuPlatform(Platform):
|
||||
vllm_config.scheduler_config.max_model_len,
|
||||
DEFAULT_MAX_NUM_BATCHED_TOKENS)
|
||||
|
||||
@classmethod
|
||||
def get_allowed_cpu_memory_node_list(
|
||||
cls) -> tuple[list[int], list[LogicalCPUInfo]]:
|
||||
assert platform.system() == "Linux"
|
||||
|
||||
# Init LogicalCPUInfo from lscpu
|
||||
lscpu_output = subprocess.check_output("lscpu -J -e=CPU,CORE,NODE",
|
||||
shell=True,
|
||||
text=True)
|
||||
logical_cpu_list: list[LogicalCPUInfo] = json.loads(
|
||||
lscpu_output, object_hook=LogicalCPUInfo.json_decoder)['cpus']
|
||||
|
||||
# Filter CPUs with invalid attributes
|
||||
logical_cpu_list = [
|
||||
x for x in logical_cpu_list
|
||||
if -1 not in (x.id, x.physical_core, x.numa_node)
|
||||
]
|
||||
|
||||
# Filter allowed CPUs
|
||||
allowed_cpu_id_list = os.sched_getaffinity(0)
|
||||
logical_cpu_list = [
|
||||
x for x in logical_cpu_list if x.id in allowed_cpu_id_list
|
||||
]
|
||||
|
||||
# Get allowed NUMA nodes
|
||||
allowed_numa_nodes = set()
|
||||
for x in logical_cpu_list:
|
||||
allowed_numa_nodes.add(x.numa_node) # type: ignore
|
||||
allowed_numa_nodes_list = sorted(allowed_numa_nodes)
|
||||
|
||||
return allowed_numa_nodes_list, logical_cpu_list
|
||||
|
||||
@classmethod
|
||||
def is_pin_memory_available(cls) -> bool:
|
||||
logger.warning("Pin memory is not supported on CPU.")
|
||||
|
||||
@ -45,9 +45,10 @@ class CPUModelRunner(GPUModelRunner):
|
||||
if k.endswith("_cpu_tensor") and isinstance(v, torch.Tensor):
|
||||
replace_tensor(self.input_batch, k, k[:-11])
|
||||
|
||||
for k, v in vars(self.input_batch.block_table).items():
|
||||
if k.endswith("_cpu") and isinstance(v, torch.Tensor):
|
||||
replace_tensor(self.input_batch.block_table, k, k[:-4])
|
||||
for block_table in self.input_batch.block_table.block_tables:
|
||||
for k, v in vars(block_table).items():
|
||||
if k.endswith("_cpu") and isinstance(v, torch.Tensor):
|
||||
replace_tensor(block_table, k, k[:-4])
|
||||
|
||||
def load_model(self, eep_scale_up: bool = False) -> None:
|
||||
logger.info("Starting to load model %s...", self.model_config.model)
|
||||
|
||||
@ -1,8 +1,8 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import os
|
||||
from importlib import util
|
||||
from typing import Optional
|
||||
import platform
|
||||
from typing import Callable, Optional
|
||||
|
||||
import torch
|
||||
|
||||
@ -12,21 +12,14 @@ from vllm.distributed.parallel_state import get_pp_group, get_tp_group
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.utils import set_random_seed
|
||||
from vllm.platforms import CpuArchEnum, current_platform
|
||||
from vllm.platforms.cpu import CpuPlatform, LogicalCPUInfo
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.utils import PlaceholderModule
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.outputs import ModelRunnerOutput
|
||||
from vllm.v1.worker.cpu_model_runner import CPUModelRunner
|
||||
from vllm.v1.worker.gpu_worker import (Worker,
|
||||
init_worker_distributed_environment)
|
||||
|
||||
try:
|
||||
import psutil
|
||||
from numa import info
|
||||
except ImportError:
|
||||
psutil = PlaceholderModule("psutil") # type: ignore[assignment]
|
||||
numa = PlaceholderModule("numa") # type: ignore[assignment]
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@ -45,20 +38,21 @@ class CPUWorker(Worker):
|
||||
is_driver_worker=is_driver_worker)
|
||||
|
||||
self.parallel_config.disable_custom_all_reduce = True
|
||||
self.manually_bind_threads_suggestion = (
|
||||
"To get better performance, please try to manually bind threads.")
|
||||
|
||||
def init_device(self):
|
||||
# Setup OpenMP threads affinity.
|
||||
omp_cpuids = envs.VLLM_CPU_OMP_THREADS_BIND
|
||||
self.local_omp_cpuid = "all"
|
||||
if omp_cpuids == "auto":
|
||||
if omp_cpuids == "auto" and platform.system() == "Linux":
|
||||
if current_platform.get_cpu_architecture() == CpuArchEnum.POWERPC:
|
||||
self.local_omp_cpuid = (
|
||||
self.get_cpus_id_binding_based_on_numa_nodes_ppc64le())
|
||||
# For POWERPC SMT-8/4/2
|
||||
self.local_omp_cpuid = self._get_autobind_cpu_ids(
|
||||
lambda cpus: [cpu for cpu in cpus if cpu.id % 8 < 4])
|
||||
elif current_platform.get_cpu_architecture() == CpuArchEnum.X86:
|
||||
# For x86 SMT-2, use 1 CPU per core
|
||||
self.local_omp_cpuid = self._get_autobind_cpu_ids(
|
||||
lambda cpus: cpus[-1:])
|
||||
else:
|
||||
self.local_omp_cpuid = (
|
||||
self.get_cpus_id_binding_based_on_numa_nodes())
|
||||
self.local_omp_cpuid = "all"
|
||||
else:
|
||||
self.local_omp_cpuid = omp_cpuids.split("|")[self.rank]
|
||||
|
||||
@ -122,126 +116,58 @@ class CPUWorker(Worker):
|
||||
assert isinstance(output, ModelRunnerOutput)
|
||||
return output if self.is_driver_worker else None
|
||||
|
||||
def warn_inability_to_detect_numa(self) -> None:
|
||||
logger.warning(
|
||||
"Auto thread-binding failed due to the "
|
||||
"inability to detect numa nodes. %s",
|
||||
self.manually_bind_threads_suggestion)
|
||||
|
||||
def warn_lack_of_numa_and_psutil(self) -> None:
|
||||
logger.warning(
|
||||
"Auto thread-binding failed due to "
|
||||
"the lack of package numa and psutil. %s",
|
||||
self.manually_bind_threads_suggestion)
|
||||
|
||||
def warn_world_size_too_large(self, world_size: int,
|
||||
node_to_cpus_len: int) -> None:
|
||||
logger.warning(
|
||||
"Auto thread-binding failed due to "
|
||||
"world size: %d being larger than "
|
||||
"allowed NUMA nodes number: %d. %s", world_size, node_to_cpus_len,
|
||||
self.manually_bind_threads_suggestion)
|
||||
|
||||
def get_cpus_allow_list_and_numa_size(self):
|
||||
cpus_allow_list = psutil.Process().cpu_affinity()
|
||||
numa_size = info.get_num_configured_nodes()
|
||||
return cpus_allow_list, numa_size
|
||||
|
||||
def auto_thread_binding_based_on_numa_nodes(self, world_size: int,
|
||||
rank_to_cpus: str) -> str:
|
||||
cpu_count = psutil.cpu_count(logical=False)
|
||||
cpus_allow_list, numa_size = self.get_cpus_allow_list_and_numa_size()
|
||||
if not numa_size:
|
||||
self.warn_inability_to_detect_numa()
|
||||
return rank_to_cpus
|
||||
|
||||
cpu_count_per_numa = cpu_count // numa_size
|
||||
num_of_reserved_cpu = min(envs.VLLM_CPU_NUM_OF_RESERVED_CPU,
|
||||
cpu_count_per_numa // 2)
|
||||
|
||||
node_to_cpus = []
|
||||
for i in range(numa_size):
|
||||
node_intersect = set(
|
||||
info.node_to_cpus(i)).intersection(cpus_allow_list)
|
||||
if bool(node_intersect):
|
||||
node_to_cpus.append(list(node_intersect))
|
||||
|
||||
node_to_cpus_len = len(node_to_cpus)
|
||||
if world_size > node_to_cpus_len:
|
||||
self.warn_world_size_too_large(world_size, node_to_cpus_len)
|
||||
else:
|
||||
end = cpu_count_per_numa - num_of_reserved_cpu
|
||||
rank_to_cpus_list = node_to_cpus[self.rank][:end]
|
||||
rank_to_cpus = ','.join(str(x) for x in rank_to_cpus_list)
|
||||
logger.info("auto thread-binding list: %s", rank_to_cpus)
|
||||
return rank_to_cpus
|
||||
|
||||
def libnuma_and_psutil_found(self) -> bool:
|
||||
libnuma_found = util.find_spec("numa") is not None
|
||||
psutil_found = util.find_spec("psutil") is not None
|
||||
|
||||
return libnuma_found and psutil_found
|
||||
|
||||
def get_cpus_id_binding_based_on_numa_nodes(self) -> str:
|
||||
"""Return CPUs id binding based on NUMA nodes.
|
||||
def _get_autobind_cpu_ids(
|
||||
self, cpu_selector: Callable[[list[LogicalCPUInfo]],
|
||||
list[LogicalCPUInfo]]
|
||||
) -> str:
|
||||
"""
|
||||
rank_to_cpus = self.local_omp_cpuid
|
||||
# Setup OpenMP thread affinity based on NUMA nodes automatically
|
||||
world_size = self.vllm_config.parallel_config.world_size
|
||||
if self.libnuma_and_psutil_found():
|
||||
rank_to_cpus = self.auto_thread_binding_based_on_numa_nodes(
|
||||
world_size, rank_to_cpus)
|
||||
else:
|
||||
self.warn_lack_of_numa_and_psutil()
|
||||
return rank_to_cpus
|
||||
|
||||
def select_threads_per_power_core(self,
|
||||
node_cpu_ids: list[int]) -> list[int]:
|
||||
return [cpu for cpu in node_cpu_ids if cpu % 8 < 4]
|
||||
|
||||
def auto_thread_binding_based_on_numa_nodes_ppc64le(
|
||||
self, world_size: int, rank_to_cpus: str) -> str:
|
||||
cpus_allow_list, numa_size = self.get_cpus_allow_list_and_numa_size()
|
||||
if not numa_size:
|
||||
self.warn_inability_to_detect_numa()
|
||||
return rank_to_cpus
|
||||
|
||||
node_to_cpus = []
|
||||
for i in range(numa_size):
|
||||
node_intersect = set(
|
||||
info.node_to_cpus(i)).intersection(cpus_allow_list)
|
||||
if bool(node_intersect):
|
||||
node_to_cpus.append(sorted(list(node_intersect)))
|
||||
|
||||
node_to_cpus_len = len(node_to_cpus)
|
||||
if world_size > node_to_cpus_len:
|
||||
self.warn_world_size_too_large(world_size, node_to_cpus_len)
|
||||
else:
|
||||
node_cpus_this_rank = node_to_cpus[self.rank]
|
||||
node_cpus_this_rank = self.select_threads_per_power_core(
|
||||
node_cpus_this_rank)
|
||||
cpu_count_per_numa = len(node_cpus_this_rank)
|
||||
num_of_reserved_cpu = min(envs.VLLM_CPU_NUM_OF_RESERVED_CPU,
|
||||
cpu_count_per_numa // 2)
|
||||
end = cpu_count_per_numa - num_of_reserved_cpu
|
||||
rank_to_cpus_list = node_cpus_this_rank[:end]
|
||||
rank_to_cpus = ','.join(str(x) for x in rank_to_cpus_list)
|
||||
logger.info("ppc64le thread-binding list: %s", rank_to_cpus)
|
||||
return rank_to_cpus
|
||||
|
||||
def get_cpus_id_binding_based_on_numa_nodes_ppc64le(self) -> str:
|
||||
"""
|
||||
Power (ppc64le) specific: Selects a subset of threads per core for
|
||||
each NUMA node.This is robust to SMT mode (SMT-8, SMT-4, etc)
|
||||
because the OS only exposes available threads.This maximizes
|
||||
performance by avoiding oversubscription of logical CPUs on Power.
|
||||
Return CPU ids to bind based on NUMA nodes.
|
||||
Currently for rank N, only CPU ids on the N-th node in available NUMA
|
||||
node list will be selected.
|
||||
Args:
|
||||
cpu_selector: a callable object to select CPUs from a CPU list
|
||||
of a physical core. The input is a LogicalCPUInfo list, sorted by
|
||||
the LogicalCPUInfo.id. A selected LogicalCPUInfo list should be
|
||||
returned.
|
||||
"""
|
||||
|
||||
rank_to_cpus = self.local_omp_cpuid
|
||||
world_size = self.vllm_config.parallel_config.world_size
|
||||
if self.libnuma_and_psutil_found():
|
||||
rank_to_cpus = self.auto_thread_binding_based_on_numa_nodes_ppc64le(
|
||||
world_size, rank_to_cpus)
|
||||
else:
|
||||
self.warn_lack_of_numa_and_psutil()
|
||||
return rank_to_cpus
|
||||
allowed_numa_nodes, logical_cpu_list = \
|
||||
CpuPlatform.get_allowed_cpu_memory_node_list()
|
||||
assert len(allowed_numa_nodes) >= self.parallel_config.world_size, (
|
||||
f"No enough allowed NUMA nodes to bind threads of "
|
||||
f"{self.parallel_config.world_size} CPUWorkers. "
|
||||
f"Allowed NUMA nodes are {allowed_numa_nodes}. "
|
||||
"Please try to bind threads manually.")
|
||||
|
||||
# Get CPUs on NUMA node `allowed_numa_nodes[local_rank]``
|
||||
selected_numa_node = allowed_numa_nodes[
|
||||
self.local_rank] # type: ignore
|
||||
logical_cpu_list = [
|
||||
x for x in logical_cpu_list if x.numa_node == selected_numa_node
|
||||
]
|
||||
|
||||
# Select CPUs from each physical core via cpu_selector
|
||||
core_to_cpus: dict[int, list[LogicalCPUInfo]] = {}
|
||||
for cpu_info in logical_cpu_list:
|
||||
if cpu_info.physical_core not in core_to_cpus:
|
||||
core_to_cpus[cpu_info.physical_core] = []
|
||||
core_to_cpus[cpu_info.physical_core].append(cpu_info)
|
||||
logical_cpu_list = []
|
||||
for cpu_list in core_to_cpus.values():
|
||||
cpu_list = sorted(cpu_list, key=lambda x: x.id)
|
||||
logical_cpu_list.extend(cpu_selector(cpu_list))
|
||||
logical_cpu_list = sorted(logical_cpu_list, key=lambda x: x.id)
|
||||
|
||||
# Reserve CPUs for other processes
|
||||
reserve_cpu_num = envs.VLLM_CPU_NUM_OF_RESERVED_CPU
|
||||
if reserve_cpu_num is None:
|
||||
reserve_cpu_num = 1 if self.parallel_config.world_size > 1 else 0
|
||||
assert len(logical_cpu_list) > reserve_cpu_num, (
|
||||
f"VLLM_CPU_NUM_OF_RESERVED_CPU ({reserve_cpu_num}) "
|
||||
f"should less than {len(logical_cpu_list)}.")
|
||||
if reserve_cpu_num != 0:
|
||||
logical_cpu_list = logical_cpu_list[:-reserve_cpu_num]
|
||||
|
||||
logger.info("auto thread-binding list (id, physical core): %s",
|
||||
[(x.id, x.physical_core) for x in logical_cpu_list])
|
||||
return ",".join([str(x.id) for x in logical_cpu_list])
|
||||
|
||||
Reference in New Issue
Block a user