[CI/Build] Replace lm-eval gsm8k tests with faster implementation (#23002)
Signed-off-by: mgoin <mgoin64@gmail.com>
```diff
@@ -451,13 +451,11 @@ steps:
 - label: LM Eval Small Models # 53min
   mirror_hardwares: [amdexperimental]
-  working_dir: "/vllm-workspace/.buildkite/lm-eval-harness"
   source_file_dependencies:
   - csrc/
   - vllm/model_executor/layers/quantization
   commands:
-  - export VLLM_WORKER_MULTIPROC_METHOD=spawn
-  - pytest -s -v test_lm_eval_correctness.py --config-list-file=configs/models-small.txt --tp-size=1
+  - pytest -s -v evals/gsm8k/test_gsm8k_correctness.py --config-list-file=configs/models-small.txt --tp-size=1

 - label: OpenAI API correctness
   mirror_hardwares: [amdexperimental]
```
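The old step exported `VLLM_WORKER_MULTIPROC_METHOD=spawn` and ran lm-eval from its own working directory; the new step is a single pytest invocation. A sketch of reproducing it locally (assuming the pipeline's default working directory resolves to the repository's `tests/` directory, which this hunk does not show):

```bash
# Sketch: reproduce the CI step locally from the repo root.
cd tests
pytest -s -v evals/gsm8k/test_gsm8k_correctness.py \
    --config-list-file=configs/models-small.txt --tp-size=1
```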
tests/evals/gsm8k/README.md (new file, 35 lines)

# GSM8K Accuracy Evaluation

This directory contains a replacement for the lm-eval-harness GSM8K evaluation. It runs an isolated GSM8K script against a vLLM server, which gives better performance and more control.

## Usage

### Run tests with pytest (as Buildkite does)

```bash
pytest -s -v tests/evals/gsm8k/test_gsm8k_correctness.py \
    --config-list-file=configs/models-small.txt \
    --tp-size=1
```

### Run standalone evaluation script

```bash
# Start vLLM server first
vllm serve Qwen/Qwen2.5-1.5B-Instruct --port 8000

# Run evaluation
python tests/evals/gsm8k/gsm8k_eval.py --port 8000
```
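Sample output from the standalone script (the format comes from `main()` in `gsm8k_eval.py`; the numbers here are illustrative):

```text
Results:
Accuracy: 0.540
Invalid responses: 0.008
Total latency: 95.210 s
Questions per second: 13.854
```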
## Configuration Format

Model configs in the `configs/` directory use this YAML format:

```yaml
model_name: "Qwen/Qwen2.5-1.5B-Instruct"
accuracy_threshold: 0.54  # Minimum expected accuracy
num_questions: 1319       # Number of questions (default: full test set)
num_fewshot: 5            # Few-shot examples from train set
max_model_len: 4096       # Model context length
```
tests/evals/gsm8k/__init__.py (new file, 2 lines)

```python
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
```
tests/evals/gsm8k/configs/Llama-3-8B-Instruct-nonuniform-CT.yaml (new file, 5 lines)

```yaml
model_name: "nm-testing/Meta-Llama-3-8B-Instruct-nonuniform-test"
accuracy_threshold: 0.74
num_questions: 1319
num_fewshot: 5
max_model_len: 4096
```
tests/evals/gsm8k/configs/Llama-3.2-1B-Instruct-INT8-CT.yaml (new file, 5 lines)

```yaml
model_name: "RedHatAI/Llama-3.2-1B-Instruct-quantized.w8a8"
accuracy_threshold: 0.31
num_questions: 1319
num_fewshot: 5
max_model_len: 4096
```
tests/evals/gsm8k/configs/Qwen1.5-MoE-W4A16-CT.yaml (new file, 5 lines)

```yaml
model_name: "nm-testing/Qwen1.5-MoE-A2.7B-Chat-quantized.w4a16"
accuracy_threshold: 0.45
num_questions: 1319
num_fewshot: 5
max_model_len: 4096
```
tests/evals/gsm8k/configs/Qwen2.5-VL-3B-Instruct-FP8-dynamic.yaml (new file, 5 lines)

```yaml
model_name: "RedHatAI/Qwen2.5-VL-3B-Instruct-FP8-Dynamic"
accuracy_threshold: 0.60
num_questions: 1319
num_fewshot: 5
max_model_len: 4096
```
tests/evals/gsm8k/configs/Qwen3-0.6B-FP8.yaml (new file, 5 lines)

```yaml
model_name: "Qwen/Qwen3-0.6B-FP8"
accuracy_threshold: 0.375
num_questions: 1319
num_fewshot: 5
max_model_len: 4096
```
tests/evals/gsm8k/configs/models-small.txt (new file, 5 lines)

```text
Qwen3-0.6B-FP8.yaml
Llama-3.2-1B-Instruct-INT8-CT.yaml
Llama-3-8B-Instruct-nonuniform-CT.yaml
Qwen2.5-VL-3B-Instruct-FP8-dynamic.yaml
Qwen1.5-MoE-W4A16-CT.yaml
```
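Adding a model to this suite is then a two-step change: drop a YAML config in `configs/` and append its filename to the list. A sketch (the model name, filename, and threshold below are illustrative, not part of this commit):

```bash
# Hypothetical new entry; pick the threshold from a trusted baseline run.
cat > tests/evals/gsm8k/configs/My-New-Model.yaml <<'EOF'
model_name: "my-org/My-New-Model"
accuracy_threshold: 0.50
num_questions: 1319
num_fewshot: 5
max_model_len: 4096
EOF
echo "My-New-Model.yaml" >> tests/evals/gsm8k/configs/models-small.txt
```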
tests/evals/gsm8k/conftest.py (new file, 66 lines)

```python
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from pathlib import Path


def pytest_addoption(parser):
    """Add custom command line options."""
    parser.addoption("--config-list-file",
                     default="configs/models-small.txt",
                     help="File containing list of config files to test")
    parser.addoption("--tp-size",
                     default=1,
                     type=int,
                     help="Tensor parallel size")


def pytest_generate_tests(metafunc):
    """Generate test parameters from config files."""
    if "config_filename" in metafunc.fixturenames:
        config_list_file = metafunc.config.getoption("--config-list-file")
        tp_size = metafunc.config.getoption("--tp-size")

        # Handle both relative and absolute paths
        config_list_path = Path(config_list_file)
        if not config_list_path.is_absolute():
            # If relative, try relative to the test directory first
            test_dir_path = Path(__file__).parent / config_list_file
            if test_dir_path.exists():
                config_list_path = test_dir_path
            else:
                # Fall back to the current working directory
                config_list_path = Path.cwd() / config_list_file

        print(f"Looking for config list at: {config_list_path}")

        config_files = []
        if config_list_path.exists():
            # Config files live in the same directory as the list file
            config_dir = config_list_path.parent

            with open(config_list_path) as f:
                for line in f:
                    line = line.strip()
                    if line and not line.startswith("#"):
                        config_path = config_dir / line
                        print(f"Checking config file: {config_path}")
                        if config_path.exists():
                            config_files.append(config_path)
                            print(f"  ✓ Found: {config_path}")
                        else:
                            print(f"  ✗ Missing: {config_path}")
        else:
            print(f"Config list file not found: {config_list_path}")

        # Generate test parameters
        if config_files:
            metafunc.parametrize(["config_filename", "tp_size"],
                                 [(config_file, int(tp_size))
                                  for config_file in config_files],
                                 ids=[
                                     f"{config_file.stem}-tp{tp_size}"
                                     for config_file in config_files
                                 ])
        else:
            print("No config files found, test will be skipped")
```
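Since `pytest_generate_tests` builds test IDs as `{config_file.stem}-tp{tp_size}`, a single model can be selected with pytest's `-k` filter. A sketch using one of the configs above:

```bash
# Runs only the Qwen3-0.6B-FP8 config; the collected test is
# test_gsm8k_correctness_param[Qwen3-0.6B-FP8-tp1]
pytest -s -v tests/evals/gsm8k/test_gsm8k_correctness.py \
    -k "Qwen3-0.6B-FP8" --tp-size=1
```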
tests/evals/gsm8k/gsm8k_eval.py (new file, 252 lines)

```python
#!/usr/bin/env python3
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Isolated GSM8K evaluation script for a vLLM serve endpoint.
"""

import argparse
import ast
import asyncio
import json
import os
import time
from collections.abc import Generator
from typing import Optional, Union

import aiohttp
import numpy as np
import regex as re
import requests
from tqdm.asyncio import tqdm

# Sentinel for answers that could not be parsed into a number.
INVALID = -9999999


def download_and_cache_file(url: str, filename: Optional[str] = None) -> str:
    """Download and cache a file from a URL."""
    if filename is None:
        filename = os.path.join("/tmp", url.split("/")[-1])

    if os.path.exists(filename):
        return filename

    print(f"Downloading from {url} to {filename}")
    response = requests.get(url, stream=True)
    response.raise_for_status()

    with open(filename, "wb") as f:
        for chunk in response.iter_content(chunk_size=1024):
            f.write(chunk)

    return filename


def load_gsm8k_data() -> tuple[list[dict], list[dict]]:
    """Load the GSM8K train and test splits."""
    train_url = "https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/train.jsonl"
    test_url = "https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl"

    train_file = download_and_cache_file(train_url)
    test_file = download_and_cache_file(test_url)

    train_data = list(read_jsonl(train_file))
    test_data = list(read_jsonl(test_file))

    return train_data, test_data


def read_jsonl(filename: str) -> Generator[dict, None, None]:
    """Read a JSONL file, skipping comment lines."""
    with open(filename) as fin:
        for line in fin:
            if not line.startswith("#"):
                yield json.loads(line)


def get_answer_value(answer_str: str) -> int:
    """Extract the last numerical value from a response string."""
    answer_str = answer_str.replace(",", "")
    numbers = re.findall(r"\d+", answer_str)
    if len(numbers) < 1:
        return INVALID
    try:
        return ast.literal_eval(numbers[-1])
    except SyntaxError:
        return INVALID


async def call_vllm_api(session: aiohttp.ClientSession,
                        prompt: str,
                        temperature: float,
                        max_tokens: int,
                        stop: Optional[list[str]] = None,
                        url: Optional[str] = None,
                        seed: Optional[int] = None) -> str:
    """Call vLLM's OpenAI-compatible completions endpoint."""
    data = {
        "prompt": prompt,
        "temperature": temperature,
        "max_tokens": max_tokens,
        "stop": stop,
    }
    if seed is not None:
        data["seed"] = seed

    try:
        async with session.post(f"{url}/v1/completions",
                                json=data) as response:
            response.raise_for_status()
            result = await response.json()
            return result["choices"][0]["text"]
    except Exception as e:
        print(f"Error calling vLLM API: {e}")
        return ""


def evaluate_gsm8k(num_questions: int = 1319,
                   num_shots: int = 5,
                   max_tokens: int = 256,
                   host: str = "http://127.0.0.1",
                   port: int = 8000,
                   temperature: float = 0.0,
                   seed: Optional[int] = 42) -> dict[str, Union[float, int]]:
    """
    Evaluate GSM8K accuracy using a vLLM serve endpoint.

    Returns a dict with accuracy, invalid_rate, latency, etc.
    """
    base_url = f"{host}:{port}"

    # Load GSM8K train and test data
    train_data, test_data = load_gsm8k_data()

    # Limit to available test questions
    num_questions = min(num_questions, len(test_data))

    # Build few-shot examples from the train split (as lm-eval does)
    few_shot_examples = ""
    for i in range(num_shots):
        few_shot_examples += (f"Question: {train_data[i]['question']}\n"
                              f"Answer: {train_data[i]['answer']}\n\n")

    # Prepare test questions and labels from the test split
    questions = []
    labels = []
    for i in range(num_questions):
        questions.append(f"Question: {test_data[i]['question']}\nAnswer:")
        labels.append(get_answer_value(test_data[i]["answer"]))

    assert all(label != INVALID for label in labels), "Some labels are invalid"

    # Run evaluation
    async def run_async_evaluation():
        states: list[str] = [""] * num_questions

        async def get_answer(session: aiohttp.ClientSession, i: int) -> str:
            prompt = few_shot_examples + questions[i]
            answer = await call_vllm_api(
                session=session,
                prompt=prompt,
                temperature=temperature,
                max_tokens=max_tokens,
                stop=["Question", "Assistant:", "<|separator|>"],
                url=base_url,
                seed=seed,
            )
            states[i] = answer
            return answer

        async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(
                total=600)) as session:
            tasks = [get_answer(session, i) for i in range(num_questions)]
            await tqdm.gather(*tasks, desc="Evaluating")

        return states

    print(f"Running GSM8K evaluation: {num_questions} questions, "
          f"{num_shots}-shot")

    tic = time.perf_counter()
    states = asyncio.run(run_async_evaluation())
    latency = time.perf_counter() - tic

    # Compute metrics
    preds = [get_answer_value(state) for state in states]
    accuracy = np.mean(np.array(preds) == np.array(labels))
    invalid_rate = np.mean(np.array(preds) == INVALID)

    result = {
        "accuracy": accuracy,
        "invalid_rate": invalid_rate,
        "latency": latency,
        "questions_per_second": num_questions / latency,
        "num_questions": num_questions,
        "num_shots": num_shots,
        "max_tokens": max_tokens,
        "timestamp": time.time(),
    }

    return result


def main() -> None:
    parser = argparse.ArgumentParser(
        description="GSM8K evaluation for vLLM serve")
    parser.add_argument("--num-shots",
                        type=int,
                        default=5,
                        help="Number of few-shot examples")
    parser.add_argument("--num-questions",
                        type=int,
                        default=1319,
                        help="Number of questions to evaluate")
    parser.add_argument("--max-tokens",
                        type=int,
                        default=256,
                        help="Max tokens for generation")
    parser.add_argument("--host",
                        type=str,
                        default="http://127.0.0.1",
                        help="Host URL")
    parser.add_argument("--port", type=int, default=8000, help="Port number")
    parser.add_argument("--temperature",
                        type=float,
                        default=0.0,
                        help="Temperature for generation")
    parser.add_argument("--seed",
                        type=int,
                        default=42,
                        help="Random seed for reproducibility")
    parser.add_argument("--save-results",
                        type=str,
                        help="Save results to JSON file")

    args = parser.parse_args()

    result = evaluate_gsm8k(
        num_questions=args.num_questions,
        num_shots=args.num_shots,
        max_tokens=args.max_tokens,
        host=args.host,
        port=args.port,
        temperature=args.temperature,
        seed=args.seed,
    )

    # Print results to terminal
    print("\nResults:")
    print(f"Accuracy: {result['accuracy']:.3f}")
    print(f"Invalid responses: {result['invalid_rate']:.3f}")
    print(f"Total latency: {result['latency']:.3f} s")
    print(f"Questions per second: {result['questions_per_second']:.3f}")

    # Optional file saving
    if args.save_results:
        with open(args.save_results, "w") as f:
            json.dump(result, f, indent=2)
        print(f"Results saved to {args.save_results}")


if __name__ == "__main__":
    main()
```
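For debugging, the request that `call_vllm_api` issues can be reproduced by hand. A sketch against a local server (the prompt is illustrative; like the script, the payload omits the `model` field and relies on the server's single served model):

```bash
curl -s http://127.0.0.1:8000/v1/completions \
  -H "Content-Type: application/json" \
  -d '{
        "prompt": "Question: What is 2 + 2?\nAnswer:",
        "temperature": 0.0,
        "max_tokens": 256,
        "stop": ["Question", "Assistant:", "<|separator|>"],
        "seed": 42
      }'
```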
tests/evals/gsm8k/test_gsm8k_correctness.py (new file, 90 lines)

```python
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
GSM8K evaluation using a vLLM server and the isolated GSM8K script.
Replacement for lm-eval-harness with better performance and control.

Usage:
pytest -s -v test_gsm8k_correctness.py \
    --config-list-file=configs/models-small.txt \
    --tp-size=1
"""

import yaml

from tests.utils import RemoteOpenAIServer

from .gsm8k_eval import evaluate_gsm8k

RTOL = 0.08  # Tolerance for accuracy comparison (subtracted from threshold)


def launch_gsm8k_eval(eval_config, server_url, tp_size):
    """Launch GSM8K evaluation using our isolated script."""
    # Extract host and port from the server URL
    if "://" in server_url:
        server_url = server_url.split("://")[1]

    host_port = server_url.split("/")[0]  # Remove path if present
    if ":" in host_port:
        host, port = host_port.split(":")
        port = int(port)
    else:
        host = host_port
        port = 8000

    # Add http:// prefix if not present
    if not host.startswith("http"):
        host = f"http://{host}"

    # Run GSM8K evaluation
    results = evaluate_gsm8k(
        num_questions=eval_config["num_questions"],
        num_shots=eval_config["num_fewshot"],
        host=host,
        port=port,
    )

    return results


def test_gsm8k_correctness_param(config_filename, tp_size):
    """Test GSM8K correctness for a given model configuration."""
    eval_config = yaml.safe_load(config_filename.read_text(encoding="utf-8"))

    # Server arguments
    server_args = [
        "--max-model-len",
        str(eval_config.get("max_model_len", 4096)),
        "--enforce-eager",
        "--trust-remote-code",
        "--tensor-parallel-size",
        str(tp_size),
    ]

    # Launch server and run evaluation
    with RemoteOpenAIServer(eval_config["model_name"],
                            server_args,
                            max_wait_seconds=480) as remote_server:
        server_url = remote_server.url_for("v1")

        results = launch_gsm8k_eval(eval_config, server_url, tp_size)

    # Check accuracy against threshold
    measured_accuracy = results["accuracy"]
    expected_accuracy = eval_config["accuracy_threshold"]

    print(f"GSM8K Results for {eval_config['model_name']}:")
    print(f"  Accuracy: {measured_accuracy:.3f}")
    print(f"  Expected: {expected_accuracy:.3f}")
    print(f"  Questions: {results['num_questions']}")
    print(f"  Invalid rate: {results['invalid_rate']:.3f}")
    print(f"  Latency: {results['latency']:.1f}s")
    print(f"  QPS: {results['questions_per_second']:.1f}")

    # Verify accuracy is within tolerance
    assert measured_accuracy >= expected_accuracy - RTOL, (
        f"Accuracy too low: {measured_accuracy:.3f} < "
        f"{expected_accuracy:.3f} - {RTOL:.3f}")

    print(f"✅ GSM8K test passed for {eval_config['model_name']}")
```
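Note that despite its name, `RTOL` is applied as an absolute margin: a model whose config says `accuracy_threshold: 0.54` passes at any measured accuracy of at least 0.46. A minimal sketch of the check:

```python
# Sketch of the acceptance check above (values illustrative).
RTOL = 0.08
expected_accuracy = 0.54  # from the model's YAML config
measured_accuracy = 0.51  # a hypothetical measurement

# Passes, since 0.51 >= 0.54 - 0.08 = 0.46
assert measured_accuracy >= expected_accuracy - RTOL
```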