[CI/Build] Replace lm-eval gsm8k tests with faster implementation (#23002)
Signed-off-by: mgoin <mgoin64@gmail.com>
@@ -451,13 +451,11 @@ steps:
 - label: LM Eval Small Models # 53min
   mirror_hardwares: [amdexperimental]
   working_dir: "/vllm-workspace/.buildkite/lm-eval-harness"
   source_file_dependencies:
   - csrc/
   - vllm/model_executor/layers/quantization
   commands:
   - export VLLM_WORKER_MULTIPROC_METHOD=spawn
-  - pytest -s -v test_lm_eval_correctness.py --config-list-file=configs/models-small.txt --tp-size=1
+  - pytest -s -v evals/gsm8k/test_gsm8k_correctness.py --config-list-file=configs/models-small.txt --tp-size=1

 - label: OpenAI API correctness
   mirror_hardwares: [amdexperimental]
tests/evals/gsm8k/README.md (new file, 35 lines)
@@ -0,0 +1,35 @@
# GSM8K Accuracy Evaluation

This directory contains a replacement for the lm-eval-harness GSM8K evaluation, using an isolated GSM8K script and a vLLM server for better performance and control.

## Usage

### Run tests with pytest (as in Buildkite)

```bash
pytest -s -v tests/evals/gsm8k/test_gsm8k_correctness.py \
    --config-list-file=configs/models-small.txt \
    --tp-size=1
```

### Run the standalone evaluation script

```bash
# Start the vLLM server first
vllm serve Qwen/Qwen2.5-1.5B-Instruct --port 8000

# Run the evaluation
python tests/evals/gsm8k/gsm8k_eval.py --port 8000
```

## Configuration Format

Model configs in the `configs/` directory use this YAML format:

```yaml
model_name: "Qwen/Qwen2.5-1.5B-Instruct"
accuracy_threshold: 0.54  # Minimum expected accuracy
num_questions: 1319       # Number of questions (default: full test set)
num_fewshot: 5            # Few-shot examples from the train set
max_model_len: 4096       # Model context length
```
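For reference, a minimal sketch of loading and sanity-checking one of these configs with PyYAML; the `required` key set below is an assumption based on the keys the test code reads, not part of the suite:

```python
# Minimal sketch: load one of the configs above and check its keys.
import yaml

with open("tests/evals/gsm8k/configs/Qwen3-0.6B-FP8.yaml") as f:
    cfg = yaml.safe_load(f)

required = {"model_name", "accuracy_threshold", "num_questions",
            "num_fewshot", "max_model_len"}
missing = required - cfg.keys()
assert not missing, f"Config is missing keys: {missing}"
print(cfg["model_name"], cfg["accuracy_threshold"])
```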
tests/evals/gsm8k/__init__.py (new file, 2 lines)
@@ -0,0 +1,2 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
tests/evals/gsm8k/configs/Llama-3-8B-Instruct-nonuniform-CT.yaml (new file, 5 lines)
@@ -0,0 +1,5 @@
model_name: "nm-testing/Meta-Llama-3-8B-Instruct-nonuniform-test"
accuracy_threshold: 0.74
num_questions: 1319
num_fewshot: 5
max_model_len: 4096
tests/evals/gsm8k/configs/Llama-3.2-1B-Instruct-INT8-CT.yaml (new file, 5 lines)
@@ -0,0 +1,5 @@
model_name: "RedHatAI/Llama-3.2-1B-Instruct-quantized.w8a8"
accuracy_threshold: 0.31
num_questions: 1319
num_fewshot: 5
max_model_len: 4096
tests/evals/gsm8k/configs/Qwen1.5-MoE-W4A16-CT.yaml (new file, 5 lines)
@@ -0,0 +1,5 @@
model_name: "nm-testing/Qwen1.5-MoE-A2.7B-Chat-quantized.w4a16"
accuracy_threshold: 0.45
num_questions: 1319
num_fewshot: 5
max_model_len: 4096
tests/evals/gsm8k/configs/Qwen2.5-VL-3B-Instruct-FP8-dynamic.yaml (new file, 5 lines)
@@ -0,0 +1,5 @@
model_name: "RedHatAI/Qwen2.5-VL-3B-Instruct-FP8-Dynamic"
accuracy_threshold: 0.60
num_questions: 1319
num_fewshot: 5
max_model_len: 4096
tests/evals/gsm8k/configs/Qwen3-0.6B-FP8.yaml (new file, 5 lines)
@@ -0,0 +1,5 @@
model_name: "Qwen/Qwen3-0.6B-FP8"
accuracy_threshold: 0.375
num_questions: 1319
num_fewshot: 5
max_model_len: 4096
tests/evals/gsm8k/configs/models-small.txt (new file, 5 lines)
@@ -0,0 +1,5 @@
Qwen3-0.6B-FP8.yaml
Llama-3.2-1B-Instruct-INT8-CT.yaml
Llama-3-8B-Instruct-nonuniform-CT.yaml
Qwen2.5-VL-3B-Instruct-FP8-dynamic.yaml
Qwen1.5-MoE-W4A16-CT.yaml
tests/evals/gsm8k/conftest.py (new file, 66 lines)
@@ -0,0 +1,66 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from pathlib import Path


def pytest_addoption(parser):
    """Add custom command line options."""
    parser.addoption("--config-list-file",
                     default="configs/models-small.txt",
                     help="File containing list of config files to test")
    parser.addoption("--tp-size",
                     default=1,
                     type=int,
                     help="Tensor parallel size")


def pytest_generate_tests(metafunc):
    """Generate test parameters from config files."""
    if "config_filename" in metafunc.fixturenames:
        config_list_file = metafunc.config.getoption("--config-list-file")
        tp_size = metafunc.config.getoption("--tp-size")

        # Handle both relative and absolute paths
        config_list_path = Path(config_list_file)
        if not config_list_path.is_absolute():
            # If relative, try relative to test directory first
            test_dir_path = Path(__file__).parent / config_list_file
            if test_dir_path.exists():
                config_list_path = test_dir_path
            else:
                # Try relative to current working directory
                config_list_path = Path.cwd() / config_list_file

        print(f"Looking for config list at: {config_list_path}")

        config_files = []
        if config_list_path.exists():
            # Determine config directory (same directory as the list file)
            config_dir = config_list_path.parent

            with open(config_list_path) as f:
                for line in f:
                    line = line.strip()
                    if line and not line.startswith("#"):
                        config_path = config_dir / line
                        print(f"Checking config file: {config_path}")
                        if config_path.exists():
                            config_files.append(config_path)
                            print(f"  ✓ Found: {config_path}")
                        else:
                            print(f"  ✗ Missing: {config_path}")
        else:
            print(f"Config list file not found: {config_list_path}")

        # Generate test parameters
        if config_files:
            metafunc.parametrize(["config_filename", "tp_size"],
                                 [(config_file, int(tp_size))
                                  for config_file in config_files],
                                 ids=[
                                     f"{config_file.stem}-tp{tp_size}"
                                     for config_file in config_files
                                 ])
        else:
            print("No config files found, test will be skipped")
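A quick illustration of what this hook produces: each test ID is the config file's stem plus the TP size. A standalone sketch, using file names from models-small.txt above:

```python
from pathlib import Path

# Mirrors the id scheme used by pytest_generate_tests above.
config_files = [
    Path("configs/Qwen3-0.6B-FP8.yaml"),
    Path("configs/Qwen1.5-MoE-W4A16-CT.yaml"),
]
tp_size = 1
ids = [f"{cfg.stem}-tp{tp_size}" for cfg in config_files]
print(ids)  # ['Qwen3-0.6B-FP8-tp1', 'Qwen1.5-MoE-W4A16-CT-tp1']
```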
tests/evals/gsm8k/gsm8k_eval.py (new file, 252 lines)
@@ -0,0 +1,252 @@
#!/usr/bin/env python3
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Isolated GSM8K evaluation script for vLLM serve endpoint.
"""

import argparse
import ast
import asyncio
import json
import os
import time
from collections.abc import Generator
from typing import Optional, Union

import aiohttp
import numpy as np
import regex as re
import requests
from tqdm.asyncio import tqdm

INVALID = -9999999


def download_and_cache_file(url: str, filename: Optional[str] = None) -> str:
    """Download and cache a file from a URL."""
    if filename is None:
        filename = os.path.join("/tmp", url.split("/")[-1])

    if os.path.exists(filename):
        return filename

    print(f"Downloading from {url} to {filename}")
    response = requests.get(url, stream=True)
    response.raise_for_status()

    with open(filename, "wb") as f:
        for chunk in response.iter_content(chunk_size=1024):
            f.write(chunk)

    return filename


def load_gsm8k_data() -> tuple[list[dict], list[dict]]:
    """Load GSM8K train and test data."""
    train_url = "https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/train.jsonl"
    test_url = "https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl"

    train_file = download_and_cache_file(train_url)
    test_file = download_and_cache_file(test_url)

    train_data = list(read_jsonl(train_file))
    test_data = list(read_jsonl(test_file))

    return train_data, test_data


def read_jsonl(filename: str) -> Generator[dict, None, None]:
    """Read a JSONL file."""
    with open(filename) as fin:
        for line in fin:
            if not line.startswith("#"):
                yield json.loads(line)


def get_answer_value(answer_str: str) -> int:
    """Extract the final numerical answer from a response string."""
    answer_str = answer_str.replace(",", "")
    numbers = re.findall(r"\d+", answer_str)
    if len(numbers) < 1:
        return INVALID
    try:
        return ast.literal_eval(numbers[-1])
    except SyntaxError:
        return INVALID

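# Example behavior (illustrative; GSM8K reference answers end in
# "#### <number>", so taking the last number matches that format):
#   get_answer_value("5 + 7 = 12 #### 12")   -> 12
#   get_answer_value("about 12,000 people")  -> 12000  (commas stripped first)
#   get_answer_value("no digits at all")     -> INVALID
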
async def call_vllm_api(session: aiohttp.ClientSession,
                        prompt: str,
                        temperature: float,
                        max_tokens: int,
                        stop: Optional[list[str]] = None,
                        url: Optional[str] = None,
                        seed: Optional[int] = None) -> str:
    """Call vLLM's OpenAI-compatible completions endpoint."""
    data = {
        "prompt": prompt,
        "temperature": temperature,
        "max_tokens": max_tokens,
        "stop": stop,
    }
    if seed is not None:
        data["seed"] = seed

    try:
        async with session.post(f"{url}/v1/completions",
                                json=data) as response:
            response.raise_for_status()
            result = await response.json()
            return result["choices"][0]["text"]
    except Exception as e:
        print(f"Error calling vLLM API: {e}")
        return ""


def evaluate_gsm8k(num_questions: int = 1319,
                   num_shots: int = 5,
                   max_tokens: int = 256,
                   host: str = "http://127.0.0.1",
                   port: int = 8000,
                   temperature: float = 0.0,
                   seed: Optional[int] = 42) -> dict[str, Union[float, int]]:
    """
    Evaluate GSM8K accuracy using a vLLM serve endpoint.

    Returns a dict with accuracy, invalid_rate, latency, etc.
    """
    base_url = f"{host}:{port}"

    # Load GSM8K train and test data
    train_data, test_data = load_gsm8k_data()

    # Limit to available test questions
    num_questions = min(num_questions, len(test_data))

    # Build few-shot examples from the train split (like lm-eval does)
    few_shot_examples = ""
    for i in range(num_shots):
        few_shot_examples += (f"Question: {train_data[i]['question']}\n"
                              f"Answer: {train_data[i]['answer']}\n\n")

    # Prepare test questions and labels from the test split
    questions = []
    labels = []
    for i in range(num_questions):
        questions.append(f"Question: {test_data[i]['question']}\nAnswer:")
        labels.append(get_answer_value(test_data[i]["answer"]))

    assert all(label != INVALID for label in labels), "Some labels are invalid"

    # Run evaluation
    async def run_async_evaluation():
        states: list[str] = [""] * num_questions

        async def get_answer(session: aiohttp.ClientSession, i: int) -> str:
            prompt = few_shot_examples + questions[i]
            answer = await call_vllm_api(
                session=session,
                prompt=prompt,
                temperature=temperature,
                max_tokens=max_tokens,
                stop=["Question", "Assistant:", "<|separator|>"],
                url=base_url,
                seed=seed,
            )
            states[i] = answer
            return answer

        async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(
                total=600)) as session:
            tasks = [get_answer(session, i) for i in range(num_questions)]
            await tqdm.gather(*tasks, desc="Evaluating")

        return states

    print(f"Running GSM8K evaluation: {num_questions} questions, "
          f"{num_shots}-shot")

    tic = time.perf_counter()
    states = asyncio.run(run_async_evaluation())
    latency = time.perf_counter() - tic

    # Compute metrics
    preds = [get_answer_value(state) for state in states]
    accuracy = np.mean(np.array(preds) == np.array(labels))
    invalid_rate = np.mean(np.array(preds) == INVALID)

    result = {
        "accuracy": accuracy,
        "invalid_rate": invalid_rate,
        "latency": latency,
        "questions_per_second": num_questions / latency,
        "num_questions": num_questions,
        "num_shots": num_shots,
        "max_tokens": max_tokens,
        "timestamp": time.time(),
    }

    return result


def main() -> None:
    parser = argparse.ArgumentParser(
        description="GSM8K evaluation for vLLM serve")
    parser.add_argument("--num-shots",
                        type=int,
                        default=5,
                        help="Number of few-shot examples")
    parser.add_argument("--num-questions",
                        type=int,
                        default=1319,
                        help="Number of questions to evaluate")
    parser.add_argument("--max-tokens",
                        type=int,
                        default=256,
                        help="Max tokens for generation")
    parser.add_argument("--host",
                        type=str,
                        default="http://127.0.0.1",
                        help="Host URL")
    parser.add_argument("--port", type=int, default=8000, help="Port number")
    parser.add_argument("--temperature",
                        type=float,
                        default=0.0,
                        help="Temperature for generation")
    parser.add_argument("--seed",
                        type=int,
                        default=42,
                        help="Random seed for reproducibility")
    parser.add_argument("--save-results",
                        type=str,
                        help="Save results to JSON file")

    args = parser.parse_args()

    result = evaluate_gsm8k(
        num_questions=args.num_questions,
        num_shots=args.num_shots,
        max_tokens=args.max_tokens,
        host=args.host,
        port=args.port,
        temperature=args.temperature,
        seed=args.seed,
    )

    # Print results to terminal
    print("\nResults:")
    print(f"Accuracy: {result['accuracy']:.3f}")
    print(f"Invalid responses: {result['invalid_rate']:.3f}")
    print(f"Total latency: {result['latency']:.3f} s")
    print(f"Questions per second: {result['questions_per_second']:.3f}")

    # Optional file saving
    if args.save_results:
        with open(args.save_results, "w") as f:
            json.dump(result, f, indent=2)
        print(f"Results saved to {args.save_results}")


if __name__ == "__main__":
    main()
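The evaluator can also be driven programmatically; a minimal sketch, assuming a vLLM server is already listening on port 8000 and the snippet is run from within `tests/evals/gsm8k/`:

```python
from gsm8k_eval import evaluate_gsm8k

# Assumes `vllm serve <model> --port 8000` is already running.
result = evaluate_gsm8k(num_questions=100, num_shots=5, port=8000)
print(f"accuracy={result['accuracy']:.3f}, "
      f"invalid_rate={result['invalid_rate']:.3f}")
```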
tests/evals/gsm8k/test_gsm8k_correctness.py (new file, 90 lines)
@@ -0,0 +1,90 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
GSM8K evaluation using a vLLM server and the isolated GSM8K script.
Replacement for lm-eval-harness with better performance and control.

Usage:
pytest -s -v test_gsm8k_correctness.py \
    --config-list-file=configs/models-small.txt \
    --tp-size=1
"""

import yaml

from tests.utils import RemoteOpenAIServer

from .gsm8k_eval import evaluate_gsm8k

RTOL = 0.08  # Accuracy tolerance (an absolute offset below the threshold)


def launch_gsm8k_eval(eval_config, server_url, tp_size):
    """Launch GSM8K evaluation using our isolated script."""
    # Extract host and port from the server URL
    if "://" in server_url:
        server_url = server_url.split("://")[1]

    host_port = server_url.split("/")[0]  # Remove path if present
    if ":" in host_port:
        host, port = host_port.split(":")
        port = int(port)
    else:
        host = host_port
        port = 8000

    # Add http:// prefix if not present
    if not host.startswith("http"):
        host = f"http://{host}"

    # Run GSM8K evaluation
    results = evaluate_gsm8k(
        num_questions=eval_config["num_questions"],
        num_shots=eval_config["num_fewshot"],
        host=host,
        port=port,
    )

    return results


def test_gsm8k_correctness_param(config_filename, tp_size):
    """Test GSM8K correctness for a given model configuration."""
    eval_config = yaml.safe_load(config_filename.read_text(encoding="utf-8"))

    # Server arguments
    server_args = [
        "--max-model-len",
        str(eval_config.get("max_model_len", 4096)),
        "--enforce-eager",
        "--trust-remote-code",
        "--tensor-parallel-size",
        str(tp_size),
    ]

    # Launch server and run evaluation
    with RemoteOpenAIServer(eval_config["model_name"],
                            server_args,
                            max_wait_seconds=480) as remote_server:
        server_url = remote_server.url_for("v1")

        results = launch_gsm8k_eval(eval_config, server_url, tp_size)

        # Check accuracy against threshold
        measured_accuracy = results["accuracy"]
        expected_accuracy = eval_config["accuracy_threshold"]

        print(f"GSM8K Results for {eval_config['model_name']}:")
        print(f"  Accuracy: {measured_accuracy:.3f}")
        print(f"  Expected: {expected_accuracy:.3f}")
        print(f"  Questions: {results['num_questions']}")
        print(f"  Invalid rate: {results['invalid_rate']:.3f}")
        print(f"  Latency: {results['latency']:.1f}s")
        print(f"  QPS: {results['questions_per_second']:.1f}")

        # Verify accuracy is within tolerance
        assert measured_accuracy >= expected_accuracy - RTOL, (
            f"Accuracy too low: {measured_accuracy:.3f} < "
            f"{expected_accuracy:.3f} - {RTOL:.3f}")

        print(f"✅ GSM8K test passed for {eval_config['model_name']}")
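Note that the pass condition is an absolute offset below the configured threshold, not a ratio; a worked instance with hypothetical numbers:

```python
RTOL = 0.08
expected_accuracy = 0.54   # accuracy_threshold from the model's YAML config
measured_accuracy = 0.47   # hypothetical measurement
# Passes because 0.47 >= 0.54 - 0.08 = 0.46
assert measured_accuracy >= expected_accuracy - RTOL
```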