Compare commits

...

7 Commits

SHA1 Message Date
0cc6a8a59d add scripts for future reference 2025-04-02 08:06:50 +00:00
7ebd6fdcd6 remove print 2025-04-01 13:45:09 +00:00
6df27b811c fix a few bugs 2025-03-31 20:42:20 +00:00
af13f065a1 fix vllm lopprob bug 2025-03-31 13:56:07 +00:00
dcca9e9edd precommit 2025-03-31 12:37:18 +00:00
b54b021021 expose vlm logprobs 2025-03-31 12:34:46 +00:00
05d6d4e22b expose vllm logprobs 2025-03-31 12:33:19 +00:00
6 changed files with 177 additions and 6 deletions

benchmark_vllm.slurm (new file, +25)

@@ -0,0 +1,25 @@
#!/bin/bash
#SBATCH --job-name=trl-vllm-math
#SBATCH --ntasks-per-node=1
#SBATCH --exclusive
#SBATCH --gres=gpu:8
#SBATCH --partition=hopper-prod # Adjust this for your cluster
#SBATCH --output=./logs/%x-%j.out
#SBATCH --err=./logs/%x-%j.err
#SBATCH --requeue
# Specific configuration optimized for the Hugging Face Compute Cluster
module load cuda/12.4
set -x -e
source ~/.bashrc
source trl-dev2/bin/activate
echo "START TIME: $(date)"
MODEL=$1
USE_VLLM_LOGPROBS=$2
CUDA_VISIBLE_DEVICES=0 trl vllm-serve --model $MODEL &
CUDA_VISIBLE_DEVICES=1,2,3,4,5,6,7 accelerate launch --config_file=examples/accelerate_configs/deepspeed_zero3.yaml --num_processes=7 run_grpo.py --model $MODEL --use_vllm_logprobs $USE_VLLM_LOGPROBS
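The script pins the vLLM server to GPU 0 and runs training on the remaining seven GPUs, so generation and optimization never compete for the same device. It takes the model name and the logprobs flag as positional arguments, e.g. `sbatch benchmark_vllm.slurm Qwen/Qwen2.5-0.5B-Instruct true` (the virtualenv path and partition above are specific to the Hugging Face cluster and will likely need adjusting).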

run_grpo.py (new file, +124)

@@ -0,0 +1,124 @@
# run_grpo.py
from datasets import load_dataset
from trl import GRPOConfig, GRPOTrainer
import argparse
from latex2sympy2_extended import NormalizationConfig
from math_verify import LatexExtractionConfig, parse, verify
import re


def accuracy_reward(completions, solution, **kwargs):
    """Reward function that checks if the completion is the same as the ground truth."""
    contents = [completion[0]["content"] for completion in completions]
    rewards = []
    for content, sol in zip(contents, solution):
        gold_parsed = parse(
            sol,
            extraction_mode="first_match",
            extraction_config=[LatexExtractionConfig()],
        )
        if len(gold_parsed) != 0:
            # We require the answer to be provided in correct latex (no malformed operators)
            answer_parsed = parse(
                content,
                extraction_config=[
                    LatexExtractionConfig(
                        normalization_config=NormalizationConfig(
                            nits=False,
                            malformed_operators=False,
                            basic_latex=True,
                            equations=True,
                            boxed="all",
                            units=True,
                        ),
                        # Ensures that boxed is tried first
                        boxed_match_priority=0,
                        try_extract_without_anchor=False,
                    )
                ],
                extraction_mode="first_match",
            )
            # Reward 1 if the content is the same as the ground truth, 0 otherwise
            try:
                reward = float(verify(answer_parsed, gold_parsed))
            except Exception as e:
                print(f"verify failed: {e}, answer: {answer_parsed}, gold: {gold_parsed}")
                reward = 0.0
        else:
            # If the gold solution is not parseable, we reward 1 to skip this example
            reward = 1.0
            print("Failed to parse gold solution: ", sol)
        rewards.append(reward)
    return rewards


def format_reward(completions, **kwargs):
    """Reward function that checks if the reasoning process is enclosed within <think> and </think> tags, while the final answer is enclosed within <answer> and </answer> tags."""
    pattern = r"^<think>\n.*?\n</think>\n<answer>\n.*?\n</answer>$"
    completion_contents = [completion[0]["content"] for completion in completions]
    matches = [re.match(pattern, content, re.DOTALL | re.MULTILINE) for content in completion_contents]
    return [1.0 if match else 0.0 for match in matches]


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Train a GRPO model with custom reward function.")
parser.add_argument("--model", type=str, default="Qwen/Qwen2.5-0.5B-Instruct", help="Output directory for the trained model.")
parser.add_argument("--logging_steps", type=int, default=1, help="Number of steps between logging.")
parser.add_argument("--max_completion_length", type=int, default=4000, help="Maximum length of completions.")
parser.add_argument("--num_iterations", type=int, default=2, help="Number of training iterations.")
parser.add_argument("--use_vllm_logprobs", type=lambda x: x.lower() in ['true', '1', 'yes'], default=False, help="Whether to use vLLM log probabilities.")
parser.add_argument("--system_prompt", type=str, default="You are a helpful AI Assistant that provides well-reasoned and detailed responses. You first think about the reasoning process as an internal monologue and then provide the user with the answer. Respond in the following format: <think>\n...\n</think>\n<answer>\n...\n</answer>", help="System prompt to prepend to the user prompt.")
args = parser.parse_args()
training_args = GRPOConfig(
output_dir=f"{args.model}-vllm_logprobs-{args.use_vllm_logprobs}",
gradient_accumulation_steps=32,
per_device_train_batch_size=4,
logging_steps=args.logging_steps,
max_completion_length=args.max_completion_length,
num_generations=14,
use_vllm=True,
use_vllm_logprobs=args.use_vllm_logprobs,
num_iterations=args.num_iterations,
reward_weights=[1.0, 0.1],
gradient_checkpointing=True,
log_completions=True,
bf16=True,
)
dataset = load_dataset("open-r1/OpenR1-Math-cn_k12-86k")
def make_conversation(example, prompt_column: str = "problem"):
prompt = []
if args.system_prompt is not None:
prompt.append({"role": "system", "content": args.system_prompt})
if prompt_column not in example:
raise ValueError(f"Dataset Question Field Error: {prompt_column} is not supported.")
prompt.append({"role": "user", "content": example[prompt_column]})
return {"prompt": prompt}
dataset = dataset.map(make_conversation)
for split in dataset:
if "messages" in dataset[split].column_names:
dataset[split] = dataset[split].remove_columns("messages")
reward_funcs = [accuracy_reward, format_reward]
model_kwargs = dict(
attn_implementation="flash_attention_2",
torch_dtype="bfloat16",
use_cache=False if training_args.gradient_checkpointing else True,
)
training_args.model_init_kwargs = model_kwargs
trainer = GRPOTrainer(
model=args.model,
reward_funcs=reward_funcs,
args=training_args,
train_dataset=dataset["train"],
)
trainer.train()
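A quick way to see what the two reward functions above expect is to call them directly on hand-written completions. The snippet below is a standalone sanity check (not part of the training script); the completions and solutions are made up for illustration, and it should be run in the same file or after importing `accuracy_reward` and `format_reward`:

```python
# Toy sanity check for the reward functions in run_grpo.py.
# Each completion mimics the conversational format GRPOTrainer passes to reward functions.
completions = [
    [{"role": "assistant", "content": "<think>\n2 + 2 = 4\n</think>\n<answer>\n\\boxed{4}\n</answer>"}],
    [{"role": "assistant", "content": "The answer is 5."}],
]
solution = ["\\boxed{4}", "\\boxed{4}"]

print(accuracy_reward(completions, solution))  # should print roughly [1.0, 0.0]
print(format_reward(completions))              # should print roughly [1.0, 0.0]
```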

trl/extras/vllm_client.py

@@ -184,7 +184,8 @@ class VLLMClient:
             },
         )
         if response.status_code == 200:
-            return response.json()["completion_ids"]
+            response_json = response.json()
+            return response_json["completion_ids"], response_json["log_probs"]
         else:
             raise Exception(f"Request failed: {response.status_code}, {response.text}")

trl/scripts/vllm_serve.py

@@ -299,6 +299,7 @@ def main(script_args: ScriptArguments):
     class GenerateResponse(BaseModel):
         completion_ids: list[list[int]]
+        log_probs: list[list[float]]
 
     @app.post("/generate/", response_model=GenerateResponse)
     async def generate(request: GenerateRequest):
@@ -312,6 +313,7 @@ def main(script_args: ScriptArguments):
         Returns:
             `GenerateResponse`:
                 - `completion_ids` (list of list of `int`): A list of lists of token IDs for each generated completion.
+                - `log_probs` (list of list of `float`): A list of lists of float representing the token logprobs.
 
         Example request:
         ```json
@@ -321,6 +323,7 @@ def main(script_args: ScriptArguments):
         Example response:
         ```json
         {"completion_ids": [[101, 102, 103], [201, 202, 203]]}
+        {"log_probs": [[1.1, 1.2, 1.3], [2.1, 2.2, 2.3]]}
         ```
         """
@@ -340,10 +343,12 @@ def main(script_args: ScriptArguments):
             min_p=request.min_p,
             max_tokens=request.max_tokens,
             guided_decoding=guided_decoding,
+            logprobs=1,
         )
         all_outputs = llm.generate(request.prompts, sampling_params=sampling_params)
         completion_ids = [list(output.token_ids) for outputs in all_outputs for output in outputs.outputs]
-        return {"completion_ids": completion_ids}
+        log_probs = [[list(p.values())[0].logprob for p in output.logprobs] for outputs in all_outputs for output in outputs.outputs]
+        return {"completion_ids": completion_ids, "log_probs": log_probs}
 
     class InitCommunicatorRequest(BaseModel):
         host: str
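For reference, the extended endpoint can also be exercised directly over HTTP. The sketch below assumes the server started above is listening on `localhost:8000` and only relies on the request fields visible in the hunk (`prompts`, `n`, `max_tokens`); anything else falls back to the server defaults:

```python
import requests

# Direct call to the /generate/ endpoint of `trl vllm-serve` (normally VLLMClient does this).
response = requests.post(
    "http://localhost:8000/generate/",
    json={"prompts": ["The capital of France is"], "n": 1, "max_tokens": 8},
)
data = response.json()
print(data["completion_ids"])  # e.g. [[101, 102, ...]] - token IDs per completion
print(data["log_probs"])       # same shape, per-token log-probabilities from vLLM
```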

trl/trainer/grpo_config.py

@@ -81,6 +81,8 @@ class GRPOConfig(TrainingArguments):
         use_vllm (`bool`, *optional*, defaults to `False`):
             Whether to use vLLM for generating completions. If set to `True`, ensure that a GPU is kept unused for
             training, as vLLM will require one for generation. vLLM must be installed (`pip install vllm`).
+        use_vllm_logprobs (`bool`, *optional*, defaults to `False`):
+            Whether to use vLLM's logprobs for the `"old_logprobs"` in the GRPO loss. Requires `use_vllm=True`.
         vllm_server_host (`str`, *optional*, defaults to `"0.0.0.0"`):
             Host of the vLLM server to connect to.
         vllm_server_port (`int`, *optional*, defaults to `8000`):
@@ -231,6 +233,12 @@ class GRPOConfig(TrainingArguments):
             "running. To run the server, install vLLM (`pip install vllm`) and run `trl vllm-serve`."
         },
     )
+    use_vllm_logprobs: bool = field(
+        default=False,
+        metadata={
+            "help": "Whether to use vLLM's logprobs for the 'old_logprobs' in the GRPO loss. Requires use_vllm=True."
+        },
+    )
     vllm_server_host: str = field(
         default="0.0.0.0",
         metadata={"help": "Host of the vLLM server to connect to."},

trl/trainer/grpo_trainer.py

@@ -708,7 +708,7 @@ class GRPOTrainer(Trainer):
                 # prompt individually.
                 ordered_set_of_prompts = all_prompts_text[:: self.num_generations]
                 with profiling_context(self, "vLLM.generate"):
-                    completion_ids = self.vllm_client.generate(
+                    completion_ids, vllm_log_probs = self.vllm_client.generate(
                         prompts=ordered_set_of_prompts,
                         n=self.num_generations,
                         repetition_penalty=self.repetition_penalty,
@@ -721,18 +721,23 @@ class GRPOTrainer(Trainer):
                     )
             else:
                 completion_ids = [None] * len(all_prompts_text)
+                vllm_log_probs = [None] * len(all_prompts_text)
 
             # Broadcast the completions from the main process to all processes, ensuring each process receives its
             # corresponding slice.
             completion_ids = broadcast_object_list(completion_ids, from_process=0)
+            vllm_log_probs = broadcast_object_list(vllm_log_probs, from_process=0)
             process_slice = slice(
                 self.accelerator.process_index * len(prompts),
                 (self.accelerator.process_index + 1) * len(prompts),
             )
             completion_ids = completion_ids[process_slice]
+            vllm_log_probs = vllm_log_probs[process_slice]
 
             # Pad the completions, and concatenate them with the prompts
             completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids]
             completion_ids = pad(completion_ids, padding_value=self.processing_class.pad_token_id)
+            vllm_log_probs = [torch.tensor(logp, device=device) for logp in vllm_log_probs]
+            vllm_log_probs = pad(vllm_log_probs, padding_value=self.processing_class.pad_token_id)
             prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1)
         else:
             # Regular generation path
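The log probabilities follow exactly the same broadcast-and-slice path as the completion IDs: the main process holds results for the whole global batch, every rank receives the full list, and each rank then keeps only its own contiguous chunk. A toy, non-distributed illustration of that slicing (the numbers match the 7-process, 4-prompts-per-process setup of the SLURM script but are otherwise arbitrary):

```python
# Stand-in for the broadcast list: one entry per prompt in the global batch.
num_processes, prompts_per_process = 7, 4
vllm_log_probs = [[-0.1 * i, -0.2 * i] for i in range(num_processes * prompts_per_process)]

for rank in range(num_processes):
    process_slice = slice(rank * prompts_per_process, (rank + 1) * prompts_per_process)
    local_log_probs = vllm_log_probs[process_slice]  # what this rank keeps
    assert len(local_log_probs) == prompts_per_process
```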
@@ -764,9 +769,12 @@ class GRPOTrainer(Trainer):
             # When using num_iterations == 1, old_per_token_logps == per_token_logps, so we can skip its
             # computation here, and use per_token_logps.detach() instead.
             if self.num_iterations > 1:
-                old_per_token_logps = self._get_per_token_logps(
-                    self.model, prompt_completion_ids, attention_mask, logits_to_keep
-                )
+                if self.args.use_vllm_logprobs:
+                    old_per_token_logps = vllm_log_probs
+                else:
+                    old_per_token_logps = self._get_per_token_logps(
+                        self.model, prompt_completion_ids, attention_mask, logits_to_keep
+                    )
             else:
                 old_per_token_logps = None
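This switch only matters when `num_iterations > 1`: the old logprobs enter the loss through an importance ratio against the current policy, and with a single iteration the two coincide, so the source of the old values is irrelevant. A toy illustration of the quantity affected by the change (all values are made up):

```python
import torch

# Per-token logprobs under the current policy vs. the "old" reference values.
per_token_logps = torch.tensor([[-1.20, -0.70, -2.10]])
old_per_token_logps = torch.tensor([[-1.25, -0.65, -2.30]])  # e.g. vllm_log_probs or a recomputed forward pass

ratio = torch.exp(per_token_logps - old_per_token_logps)  # importance ratio, clipped in the actual GRPO loss
print(ratio)  # values near 1.0; any mismatch between vLLM and HF logprobs shows up here
```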