mirror of
https://github.com/volcengine/verl.git
synced 2025-10-20 13:43:50 +08:00
313 lines
9.8 KiB
Python
313 lines
9.8 KiB
Python
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
"""Diagnose script for checking OS/hardware/python/pip/verl/network.
|
|
The output of this script can be a very good hint to issue/problem.
|
|
"""
|
|
|
|
import os
|
|
import platform
|
|
import socket
|
|
import subprocess
|
|
import sys
|
|
import time
|
|
|
|
import psutil
|
|
|
|
try:
|
|
from urllib.parse import urlparse
|
|
from urllib.request import urlopen
|
|
except ImportError:
|
|
from urllib2 import urlopen
|
|
from urlparse import urlparse
|
|
import argparse
|
|
import importlib.metadata
|
|
|
|
import torch
|
|
|
|
URLS = {
|
|
"PYPI": "https://pypi.python.org/pypi/pip",
|
|
}
|
|
|
|
REGIONAL_URLS = {
|
|
"cn": {
|
|
"PYPI(douban)": "https://pypi.douban.com/",
|
|
"Conda(tsinghua)": "https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/free/",
|
|
}
|
|
}
|
|
|
|
|
|
def test_connection(name, url, timeout=10):
|
|
"""Simple connection test"""
|
|
urlinfo = urlparse(url)
|
|
start = time.time()
|
|
try:
|
|
socket.gethostbyname(urlinfo.netloc)
|
|
except Exception as e:
|
|
print("Error resolving DNS for {}: {}, {}".format(name, url, e))
|
|
return
|
|
dns_elapsed = time.time() - start
|
|
start = time.time()
|
|
try:
|
|
_ = urlopen(url, timeout=timeout)
|
|
except Exception as e:
|
|
print("Error open {}: {}, {}, DNS finished in {} sec.".format(name, url, e, dns_elapsed))
|
|
return
|
|
load_elapsed = time.time() - start
|
|
print("Timing for {}: {}, DNS: {:.4f} sec, LOAD: {:.4f} sec.".format(name, url, dns_elapsed, load_elapsed))
|
|
|
|
|
|
def check_python():
|
|
print("----------Python Info----------")
|
|
print("Version :", platform.python_version())
|
|
print("Compiler :", platform.python_compiler())
|
|
print("Build :", platform.python_build())
|
|
print("Arch :", platform.architecture())
|
|
|
|
|
|
def check_pip():
|
|
print("------------Pip Info-----------")
|
|
try:
|
|
import pip
|
|
|
|
print("Version :", pip.__version__)
|
|
print("Directory :", os.path.dirname(pip.__file__))
|
|
except ImportError:
|
|
print("No corresponding pip install for current python.")
|
|
|
|
|
|
def _get_current_git_commit():
|
|
try:
|
|
result = subprocess.run(["git", "rev-parse", "HEAD"], capture_output=True, text=True, check=True)
|
|
return result.stdout.strip()
|
|
except subprocess.CalledProcessError as e:
|
|
print(f"Error running git command: {e.stderr.strip()}")
|
|
return None
|
|
except FileNotFoundError:
|
|
print("Did not find command: git")
|
|
return None
|
|
|
|
|
|
def check_verl():
|
|
print("----------verl Info-----------")
|
|
try:
|
|
sys.path.insert(0, os.getcwd())
|
|
import verl
|
|
|
|
print("Version :", verl.__version__)
|
|
verl_dir = os.path.dirname(verl.__file__)
|
|
print("Directory :", verl_dir)
|
|
try:
|
|
commit_hash = _get_current_git_commit()
|
|
print("Commit Hash :", commit_hash)
|
|
except AttributeError:
|
|
print("Commit hash not found. ")
|
|
except ImportError as e:
|
|
print(f"No verl installed: {e}")
|
|
except Exception as e:
|
|
import traceback
|
|
|
|
if not isinstance(e, IOError):
|
|
print("An error occurred trying to import verl.")
|
|
print("This is very likely due to missing or incompatible library files.")
|
|
print(traceback.format_exc())
|
|
|
|
|
|
def check_os():
|
|
print("----------Platform Info----------")
|
|
print("Platform :", platform.platform())
|
|
print("system :", platform.system())
|
|
print("node :", platform.node())
|
|
print("release :", platform.release())
|
|
print("version :", platform.version())
|
|
|
|
|
|
def check_hardware():
|
|
print("----------Hardware Info----------")
|
|
print("machine :", platform.machine())
|
|
print("processor :", platform.processor())
|
|
if sys.platform.startswith("darwin"):
|
|
pipe = subprocess.Popen(("sysctl", "-a"), stdout=subprocess.PIPE)
|
|
output = pipe.communicate()[0]
|
|
for line in output.split(b"\n"):
|
|
if b"brand_string" in line or b"features" in line:
|
|
print(line.strip())
|
|
elif sys.platform.startswith("linux"):
|
|
subprocess.call(["lscpu"])
|
|
elif sys.platform.startswith("win32"):
|
|
subprocess.call(["wmic", "cpu", "get", "name"])
|
|
|
|
|
|
def check_network(args):
|
|
print("----------Network Test----------")
|
|
if args.timeout > 0:
|
|
print("Setting timeout: {}".format(args.timeout))
|
|
socket.setdefaulttimeout(10)
|
|
for region in args.region.strip().split(","):
|
|
r = region.strip().lower()
|
|
if not r:
|
|
continue
|
|
if r in REGIONAL_URLS:
|
|
URLS.update(REGIONAL_URLS[r])
|
|
else:
|
|
import warnings
|
|
|
|
warnings.warn("Region {} do not need specific test, please refer to global sites.".format(r), stacklevel=2)
|
|
for name, url in URLS.items():
|
|
test_connection(name, url, args.timeout)
|
|
|
|
|
|
def check_environment():
|
|
print("----------Environment----------")
|
|
for k, v in os.environ.items():
|
|
if k.startswith("VERL_") or k.startswith("OMP_") or k.startswith("KMP_") or k == "CC" or k == "CXX":
|
|
print('{}="{}"'.format(k, v))
|
|
|
|
|
|
def check_pip_package_versions():
|
|
packages = ["vllm", "sglang", "ray", "torch"]
|
|
for package in packages:
|
|
try:
|
|
version = importlib.metadata.version(package)
|
|
print(f"{package}\t : {version}")
|
|
except importlib.metadata.PackageNotFoundError:
|
|
print(f"{package}\t : not found.")
|
|
|
|
|
|
def check_cuda_versions():
|
|
if torch.cuda.is_available():
|
|
try:
|
|
cuda_runtime_version = torch.version.cuda
|
|
print(f"CUDA Runtime : {cuda_runtime_version}")
|
|
import subprocess
|
|
|
|
nvcc_output = subprocess.check_output(["nvcc", "--version"]).decode("utf-8")
|
|
cuda_compiler_version = next((line for line in nvcc_output.splitlines() if "release" in line), None)
|
|
if cuda_compiler_version:
|
|
print(f"CUDA Compiler : {cuda_compiler_version.strip()}")
|
|
else:
|
|
print("Could not determine CUDA compiler version.")
|
|
except FileNotFoundError as e:
|
|
print(f"CUDA compiler : Not found: {e}")
|
|
except Exception as e:
|
|
print(f"An error occurred while checking CUDA versions: {e}")
|
|
else:
|
|
print("CUDA is not available.")
|
|
|
|
|
|
def _get_cpu_memory():
|
|
"""
|
|
Get the total CPU memory capacity in GB.
|
|
"""
|
|
memory = psutil.virtual_memory()
|
|
return memory.total / (1024**3)
|
|
|
|
|
|
def _get_gpu_info():
|
|
"""
|
|
Get GPU type, GPU memory, and GPU count using nvidia-smi command.
|
|
"""
|
|
try:
|
|
result = subprocess.run(
|
|
["nvidia-smi", "--query-gpu=gpu_name,memory.total", "--format=csv,noheader,nounits"],
|
|
capture_output=True,
|
|
text=True,
|
|
check=True,
|
|
)
|
|
gpu_lines = result.stdout.strip().split("\n")
|
|
gpu_count = len(gpu_lines)
|
|
gpu_info = []
|
|
for line in gpu_lines:
|
|
gpu_name, gpu_memory = line.split(", ")
|
|
gpu_info.append(
|
|
{
|
|
"type": gpu_name,
|
|
"memory": float(gpu_memory) / 1024, # Convert to GB
|
|
}
|
|
)
|
|
return gpu_count, gpu_info
|
|
except (subprocess.CalledProcessError, FileNotFoundError):
|
|
print("Failed to execute nvidia-smi command.")
|
|
return 0, []
|
|
|
|
|
|
def _get_system_info():
|
|
"""
|
|
Get CPU memory capacity, GPU type, GPU memory, and GPU count.
|
|
"""
|
|
cpu_memory = _get_cpu_memory()
|
|
gpu_count, gpu_info = _get_gpu_info()
|
|
return {"cpu_memory": cpu_memory, "gpu_count": gpu_count, "gpu_info": gpu_info}
|
|
|
|
|
|
def check_system_info():
|
|
print("----------System Info----------")
|
|
system_info = _get_system_info()
|
|
print(f"CPU Memory\t: {system_info['cpu_memory']:.2f} GB")
|
|
print(f"GPU Count\t: {system_info['gpu_count']}")
|
|
for i, gpu in enumerate(system_info["gpu_info"]):
|
|
print(f"GPU {i + 1}\tType : {gpu['type']}")
|
|
print(f"GPU {i + 1}\tMemory : {gpu['memory']:.2f} GB")
|
|
|
|
|
|
def parse_args():
|
|
"""Parse arguments."""
|
|
parser = argparse.ArgumentParser(
|
|
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
|
|
description="Diagnose script for checking the current system.",
|
|
)
|
|
choices = ["python", "pip", "verl", "system", "os", "environment"]
|
|
for choice in choices:
|
|
parser.add_argument("--" + choice, default=1, type=int, help="Diagnose {}.".format(choice))
|
|
parser.add_argument("--network", default=0, type=int, help="Diagnose network.")
|
|
parser.add_argument("--hardware", default=0, type=int, help="Diagnose hardware.")
|
|
parser.add_argument(
|
|
"--region",
|
|
default="",
|
|
type=str,
|
|
help="Additional sites in which region(s) to test. \
|
|
Specify 'cn' for example to test mirror sites in China.",
|
|
)
|
|
parser.add_argument("--timeout", default=10, type=int, help="Connection test timeout threshold, 0 to disable.")
|
|
args = parser.parse_args()
|
|
return args
|
|
|
|
|
|
if __name__ == "__main__":
|
|
args = parse_args()
|
|
if args.python:
|
|
check_python()
|
|
|
|
if args.pip:
|
|
check_pip()
|
|
check_pip_package_versions()
|
|
|
|
if args.verl:
|
|
check_verl()
|
|
|
|
if args.os:
|
|
check_os()
|
|
|
|
if args.hardware:
|
|
check_hardware()
|
|
|
|
if args.network:
|
|
check_network(args)
|
|
|
|
if args.environment:
|
|
check_environment()
|
|
check_cuda_versions()
|
|
|
|
if args.system:
|
|
check_system_info()
|