mirror of https://github.com/vllm-project/vllm.git (synced 2025-10-20 14:53:52 +08:00)
[Core][Distributed] fix _is_full_nvlink detection (#4233)
@@ -1,5 +1,6 @@
+import os
 from contextlib import contextmanager
-from typing import Optional
+from typing import List, Optional
 
 import torch
 import torch.distributed as dist
@@ -53,14 +54,20 @@ def init_custom_ar() -> None:
         return False
-    # test nvlink first, this will filter out most of the cases
-    # where custom allreduce is not supported
-    full_nvlink = _is_full_nvlink(rank, world_size)
+    if "CUDA_VISIBLE_DEVICES" in os.environ:
+        device_ids = list(
+            map(int, os.environ["CUDA_VISIBLE_DEVICES"].split(",")))
+    else:
+        device_ids = list(range(num_dev))
+    # this checks hardware and driver support for NVLink
+    full_nvlink = _is_full_nvlink(device_ids)
     if world_size > 2 and not full_nvlink:
         logger.warn(
             "Custom allreduce is disabled because it's not supported on more"
             " than two PCIe-only GPUs. To silence this warning, specify"
             " disable_custom_all_reduce=True explicitly.")
         return
-    # test P2P capability
+    # test P2P capability, this checks software/cudaruntime support
+    # this is expensive to compute at the first time
+    # then we cache the result
     if not _can_p2p(rank, world_size):
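
The bug being fixed: `rank` and `world_size` are logical indices into whatever `CUDA_VISIBLE_DEVICES` exposes, while `pynvml` always numbers physical GPUs, so the old call `_is_full_nvlink(rank, world_size)` could probe the wrong devices. A minimal sketch of the logical-to-physical mapping the new code performs (the helper name `physical_device_ids` is ours, not the commit's):

# Sketch (not part of the commit): map logical ranks to physical GPU ids.
# pynvml numbers physical devices; CUDA_VISIBLE_DEVICES only remaps what
# the CUDA runtime (and hence torch) sees.
import os

def physical_device_ids(num_dev: int) -> list:
    # num_dev plays the role of torch.cuda.device_count() in the real code
    if "CUDA_VISIBLE_DEVICES" in os.environ:
        return [int(x) for x in os.environ["CUDA_VISIBLE_DEVICES"].split(",")]
    return list(range(num_dev))

# With CUDA_VISIBLE_DEVICES="4,5", rank 0 is physical GPU 4: passing the
# rank (0) straight to pynvml, as the old code did, inspects the wrong GPU.
os.environ["CUDA_VISIBLE_DEVICES"] = "4,5"
print(physical_device_ids(2))  # [4, 5]
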
@@ -138,23 +145,28 @@ def _nvml():
         pynvml.nvmlShutdown()
 
 
-# query if the set of gpus are fully connected by nvlink (1 hop)
 @_nvml()
-def _is_full_nvlink(rank, world_size):
-    handle = pynvml.nvmlDeviceGetHandleByIndex(rank)
-    for i in range(world_size):
-        if i != rank:
-            try:
-                peer_handle = pynvml.nvmlDeviceGetHandleByIndex(i)
-                p2p_status = pynvml.nvmlDeviceGetP2PStatus(
-                    handle, peer_handle, pynvml.NVML_P2P_CAPS_INDEX_NVLINK)
-                if p2p_status != pynvml.NVML_P2P_STATUS_OK:
-                    return False
-            except pynvml.NVMLError as error:
-                logger.info(
-                    f"NVLink detection failed with message \"{str(error)}\". "
-                    "This is normal if your machine has no NVLink equipped")
-                return False
+def _is_full_nvlink(device_ids: List[int]) -> bool:
+    """
+    query if the set of gpus are fully connected by nvlink (1 hop)
+    Note that `pynvml` is not affected by `CUDA_VISIBLE_DEVICES`,
+    so it works on real physical device ids.
+    """
+    handles = [pynvml.nvmlDeviceGetHandleByIndex(i) for i in device_ids]
+    for i, handle in enumerate(handles):
+        for j, peer_handle in enumerate(handles):
+            if i < j:
+                try:
+                    p2p_status = pynvml.nvmlDeviceGetP2PStatus(
+                        handle, peer_handle, pynvml.NVML_P2P_CAPS_INDEX_NVLINK)
+                    if p2p_status != pynvml.NVML_P2P_STATUS_OK:
+                        return False
+                except pynvml.NVMLError as error:
+                    logger.error(
+                        "NVLink detection failed. This is normal if your"
+                        " machine has no NVLink equipped.",
+                        exc_info=error)
+                    return False
     return True
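
For reference, a self-contained sketch of the same pairwise check using only public `pynvml` calls; this is not vLLM's implementation, and the `nvmlInit`/`nvmlShutdown` pairing done by hand here is what the `@_nvml()` decorator supplies in the diff above:

# Standalone sketch of the pairwise NVLink probe; not vLLM's implementation.
from typing import List

import pynvml


def is_full_nvlink(device_ids: List[int]) -> bool:
    pynvml.nvmlInit()  # done by the @_nvml() decorator in the diff above
    try:
        handles = [pynvml.nvmlDeviceGetHandleByIndex(i) for i in device_ids]
        # visit each unordered pair once (i < j), mirroring the new loop
        for i, a in enumerate(handles):
            for b in handles[i + 1:]:
                status = pynvml.nvmlDeviceGetP2PStatus(
                    a, b, pynvml.NVML_P2P_CAPS_INDEX_NVLINK)
                if status != pynvml.NVML_P2P_STATUS_OK:
                    return False
        return True
    except pynvml.NVMLError:
        # typical on machines without NVLink; report "not fully connected"
        return False
    finally:
        pynvml.nvmlShutdown()


if __name__ == "__main__":
    print(is_full_nvlink([0, 1]))  # physical ids, e.g. the first two GPUs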