Mirror of https://github.com/huggingface/trl.git (synced 2025-10-20 18:43:52 +08:00)

Compare commits: f6e7c200c0...grpo-vllm- (7 commits)
SHA1: 0cc6a8a59d, 7ebd6fdcd6, 6df27b811c, af13f065a1, dcca9e9edd, b54b021021, 05d6d4e22b
benchmark_vllm.slurm (new file, 25 lines)
@@ -0,0 +1,25 @@
#!/bin/bash
#SBATCH --job-name=trl-vllm-math
#SBATCH --ntasks-per-node=1
#SBATCH --exclusive
#SBATCH --gres=gpu:8
#SBATCH --partition=hopper-prod # Adjust this for your cluster
#SBATCH --output=./logs/%x-%j.out
#SBATCH --err=./logs/%x-%j.err
#SBATCH --requeue

# Specific configuration optimized for the Hugging Face Compute Cluster
module load cuda/12.4
set -x -e

source ~/.bashrc
source trl-dev2/bin/activate
echo "START TIME: $(date)"

MODEL=$1
USE_VLLM_LOGPROBS=$2

CUDA_VISIBLE_DEVICES=0 trl vllm-serve --model $MODEL &


CUDA_VISIBLE_DEVICES=1,2,3,4,5,6,7 accelerate launch --config_file=examples/accelerate_configs/deepspeed_zero3.yaml --num_processes=7 run_grpo.py --model $MODEL --use_vllm_logprobs $USE_VLLM_LOGPROBS
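Usage sketch (the positional arguments are illustrative; the first is the model ID passed both to the vLLM server and to the trainer, the second toggles `--use_vllm_logprobs`): `sbatch benchmark_vllm.slurm Qwen/Qwen2.5-0.5B-Instruct true`.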
run_grpo.py (new file, 124 lines)
@@ -0,0 +1,124 @@
# train_grpo.py
from datasets import load_dataset
from trl import GRPOConfig, GRPOTrainer
import argparse
from latex2sympy2_extended import NormalizationConfig
from math_verify import LatexExtractionConfig, parse, verify
import re


def accuracy_reward(completions, solution, **kwargs):
    """Reward function that checks if the completion is the same as the ground truth."""
    contents = [completion[0]["content"] for completion in completions]
    rewards = []
    for content, sol in zip(contents, solution):
        gold_parsed = parse(
            sol,
            extraction_mode="first_match",
            extraction_config=[LatexExtractionConfig()],
        )
        if len(gold_parsed) != 0:
            # We require the answer to be provided in correct latex (no malformed operators)
            answer_parsed = parse(
                content,
                extraction_config=[
                    LatexExtractionConfig(
                        normalization_config=NormalizationConfig(
                            nits=False,
                            malformed_operators=False,
                            basic_latex=True,
                            equations=True,
                            boxed="all",
                            units=True,
                        ),
                        # Ensures that boxed is tried first
                        boxed_match_priority=0,
                        try_extract_without_anchor=False,
                    )
                ],
                extraction_mode="first_match",
            )
            # Reward 1 if the content is the same as the ground truth, 0 otherwise
            try:
                reward = float(verify(answer_parsed, gold_parsed))
            except Exception as e:
                print(f"verify failed: {e}, answer: {answer_parsed}, gold: {gold_parsed}")
                reward = 0.0
        else:
            # If the gold solution is not parseable, we reward 1 to skip this example
            reward = 1.0
            print("Failed to parse gold solution: ", sol)
        rewards.append(reward)

    return rewards


def format_reward(completions, **kwargs):
    """Reward function that checks if the reasoning process is enclosed within <think> and </think> tags, while the final answer is enclosed within <answer> and </answer> tags."""
    pattern = r"^<think>\n.*?\n</think>\n<answer>\n.*?\n</answer>$"
    completion_contents = [completion[0]["content"] for completion in completions]
    matches = [re.match(pattern, content, re.DOTALL | re.MULTILINE) for content in completion_contents]
    return [1.0 if match else 0.0 for match in matches]


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Train a GRPO model with custom reward function.")
    parser.add_argument("--model", type=str, default="Qwen/Qwen2.5-0.5B-Instruct", help="Model name or path to train.")
    parser.add_argument("--logging_steps", type=int, default=1, help="Number of steps between logging.")
    parser.add_argument("--max_completion_length", type=int, default=4000, help="Maximum length of completions.")
    parser.add_argument("--num_iterations", type=int, default=2, help="Number of training iterations.")
    parser.add_argument("--use_vllm_logprobs", type=lambda x: x.lower() in ['true', '1', 'yes'], default=False, help="Whether to use vLLM log probabilities.")
    parser.add_argument("--system_prompt", type=str, default="You are a helpful AI Assistant that provides well-reasoned and detailed responses. You first think about the reasoning process as an internal monologue and then provide the user with the answer. Respond in the following format: <think>\n...\n</think>\n<answer>\n...\n</answer>", help="System prompt to prepend to the user prompt.")

    args = parser.parse_args()

    training_args = GRPOConfig(
        output_dir=f"{args.model}-vllm_logprobs-{args.use_vllm_logprobs}",
        gradient_accumulation_steps=32,
        per_device_train_batch_size=4,
        logging_steps=args.logging_steps,
        max_completion_length=args.max_completion_length,
        num_generations=14,
        use_vllm=True,
        use_vllm_logprobs=args.use_vllm_logprobs,
        num_iterations=args.num_iterations,
        reward_weights=[1.0, 0.1],
        gradient_checkpointing=True,
        log_completions=True,
        bf16=True,
    )

    dataset = load_dataset("open-r1/OpenR1-Math-cn_k12-86k")

    def make_conversation(example, prompt_column: str = "problem"):
        prompt = []

        if args.system_prompt is not None:
            prompt.append({"role": "system", "content": args.system_prompt})

        if prompt_column not in example:
            raise ValueError(f"Dataset Question Field Error: {prompt_column} is not supported.")

        prompt.append({"role": "user", "content": example[prompt_column]})
        return {"prompt": prompt}

    dataset = dataset.map(make_conversation)
    for split in dataset:
        if "messages" in dataset[split].column_names:
            dataset[split] = dataset[split].remove_columns("messages")
    reward_funcs = [accuracy_reward, format_reward]

    model_kwargs = dict(
        attn_implementation="flash_attention_2",
        torch_dtype="bfloat16",
        use_cache=False if training_args.gradient_checkpointing else True,
    )
    training_args.model_init_kwargs = model_kwargs

    trainer = GRPOTrainer(
        model=args.model,
        reward_funcs=reward_funcs,
        args=training_args,
        train_dataset=dataset["train"],
    )
    trainer.train()
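A quick sanity check of the two reward functions on a toy example (the completion and solution strings are illustrative, not taken from the dataset, and the snippet assumes accuracy_reward and format_reward from this script are in scope):

```python
# Toy completion in the chat format that GRPOTrainer passes to reward functions
completions = [[{"role": "assistant", "content": "<think>\n2 + 2 = 4\n</think>\n<answer>\n\\boxed{4}\n</answer>"}]]
solution = ["\\boxed{4}"]

print(format_reward(completions))              # [1.0]: the <think>/<answer> template is respected
print(accuracy_reward(completions, solution))  # expected [1.0], assuming math_verify parses and verifies \boxed{4}
```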
@@ -184,7 +184,8 @@ class VLLMClient:
             },
         )
         if response.status_code == 200:
-            return response.json()["completion_ids"]
+            response_json = response.json()
+            return response_json["completion_ids"], response_json["log_probs"]
         else:
             raise Exception(f"Request failed: {response.status_code}, {response.text}")

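With this change, `VLLMClient.generate` returns a `(completion_ids, log_probs)` pair instead of a bare list of token IDs. A caller-side sketch (the import path and prompt are assumptions for illustration; the server is the one started by `trl vllm-serve`):

```python
from trl.extras.vllm_client import VLLMClient  # assumed module path for the client

client = VLLMClient()  # connects to the vLLM server on its default host/port
completion_ids, log_probs = client.generate(prompts=["What is 2 + 2?"], n=2)

# completion_ids: list[list[int]], token IDs of each generated completion
# log_probs: list[list[float]], per-token logprobs aligned with completion_ids
```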
@@ -299,6 +299,7 @@ def main(script_args: ScriptArguments):

     class GenerateResponse(BaseModel):
         completion_ids: list[list[int]]
+        log_probs: list[list[float]]

     @app.post("/generate/", response_model=GenerateResponse)
     async def generate(request: GenerateRequest):
@@ -312,6 +313,7 @@ def main(script_args: ScriptArguments):
         Returns:
             `GenerateResponse`:
                 - `completion_ids` (list of list of `int`): A list of lists of token IDs for each generated completion.
+                - `log_probs` (list of list of `float`): A list of lists of float representing the token logprobs.

         Example request:
         ```json
@@ -321,6 +323,7 @@ def main(script_args: ScriptArguments):
         Example response:
         ```json
         {"completion_ids": [[101, 102, 103], [201, 202, 203]]}
+        {"log_probs": [[1.1, 1.2, 1.3], [2.1, 2.2, 2.3]]}
         ```
         """

@@ -340,10 +343,12 @@ def main(script_args: ScriptArguments):
             min_p=request.min_p,
             max_tokens=request.max_tokens,
             guided_decoding=guided_decoding,
+            logprobs=1,
         )
         all_outputs = llm.generate(request.prompts, sampling_params=sampling_params)
         completion_ids = [list(output.token_ids) for outputs in all_outputs for output in outputs.outputs]
-        return {"completion_ids": completion_ids}
+        log_probs = [[list(p.values())[0].logprob for p in output.logprobs] for outputs in all_outputs for output in outputs.outputs]
+        return {"completion_ids": completion_ids, "log_probs": log_probs}

     class InitCommunicatorRequest(BaseModel):
         host: str
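For reference, a raw-HTTP sketch against the extended `/generate/` endpoint (host, port, and prompt are illustrative; the response now carries the additional `log_probs` field alongside `completion_ids`):

```python
import requests

# Illustrative request to a `trl vllm-serve` instance running on the defaults (0.0.0.0:8000)
response = requests.post("http://0.0.0.0:8000/generate/", json={"prompts": ["What is 2 + 2?"]})
data = response.json()

print(data["completion_ids"])  # list of token-ID lists, one per completion
print(data["log_probs"])       # list of per-token logprob lists, newly returned by this change
```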
@@ -81,6 +81,8 @@ class GRPOConfig(TrainingArguments):
         use_vllm (`bool`, *optional*, defaults to `False`):
             Whether to use vLLM for generating completions. If set to `True`, ensure that a GPU is kept unused for
             training, as vLLM will require one for generation. vLLM must be installed (`pip install vllm`).
+        use_vllm_logprobs (`bool`, *optional*, defaults to `False`):
+            Whether to use vLLM's logprobs for the `"old_logprobs"` in the GRPO loss. Requires `use_vllm=True`.
         vllm_server_host (`str`, *optional*, defaults to `"0.0.0.0"`):
             Host of the vLLM server to connect to.
         vllm_server_port (`int`, *optional*, defaults to `8000`):
@@ -231,6 +233,12 @@ class GRPOConfig(TrainingArguments):
             "running. To run the server, install vLLM (`pip install vllm`) and run `trl vllm-serve`."
         },
     )
+    use_vllm_logprobs: bool = field(
+        default=False,
+        metadata={
+            "help": "Whether to use vLLM's logprobs for the 'old_logprobs' in the GRPO loss. Requires use_vllm=True."
+        },
+    )
     vllm_server_host: str = field(
         default="0.0.0.0",
         metadata={"help": "Host of the vLLM server to connect to."},
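A minimal configuration sketch using the new flag (the output directory and values are illustrative):

```python
from trl import GRPOConfig

# Illustrative: reuse the sampling-time logprobs returned by the vLLM server as the
# "old" per-token logprobs in the GRPO loss instead of recomputing them with the policy model
training_args = GRPOConfig(
    output_dir="Qwen2.5-0.5B-GRPO",
    use_vllm=True,            # required by use_vllm_logprobs
    use_vllm_logprobs=True,   # new flag added in this change
    num_iterations=2,         # old logprobs only matter when num_iterations > 1
)
```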
@@ -708,7 +708,7 @@ class GRPOTrainer(Trainer):
                 # prompt individually.
                 ordered_set_of_prompts = all_prompts_text[:: self.num_generations]
                 with profiling_context(self, "vLLM.generate"):
-                    completion_ids = self.vllm_client.generate(
+                    completion_ids, vllm_log_probs = self.vllm_client.generate(
                         prompts=ordered_set_of_prompts,
                         n=self.num_generations,
                         repetition_penalty=self.repetition_penalty,
@@ -721,18 +721,23 @@ class GRPOTrainer(Trainer):
                     )
            else:
                completion_ids = [None] * len(all_prompts_text)
+                vllm_log_probs = [None] * len(all_prompts_text)
            # Broadcast the completions from the main process to all processes, ensuring each process receives its
            # corresponding slice.
            completion_ids = broadcast_object_list(completion_ids, from_process=0)
+            vllm_log_probs = broadcast_object_list(vllm_log_probs, from_process=0)
            process_slice = slice(
                self.accelerator.process_index * len(prompts),
                (self.accelerator.process_index + 1) * len(prompts),
            )
            completion_ids = completion_ids[process_slice]
+            vllm_log_probs = vllm_log_probs[process_slice]

            # Pad the completions, and concatenate them with the prompts
            completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids]
            completion_ids = pad(completion_ids, padding_value=self.processing_class.pad_token_id)
+            vllm_log_probs = [torch.tensor(logp, device=device) for logp in vllm_log_probs]
+            vllm_log_probs = pad(vllm_log_probs, padding_value=self.processing_class.pad_token_id)
            prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1)
        else:
            # Regular generation path
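A toy illustration of the broadcast-then-slice pattern used above, outside of any distributed setup (process count and prompt counts are made up for the example):

```python
# Process 0 generated completions for every prompt across all processes;
# after broadcasting, each process keeps only the slice matching its own prompts.
num_processes, prompts_per_process = 2, 3
completion_ids = [[11], [12], [13], [21], [22], [23]]  # 6 completions gathered on process 0

for process_index in range(num_processes):
    process_slice = slice(process_index * prompts_per_process, (process_index + 1) * prompts_per_process)
    print(process_index, completion_ids[process_slice])
# 0 [[11], [12], [13]]
# 1 [[21], [22], [23]]
```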
@@ -764,9 +769,12 @@ class GRPOTrainer(Trainer):
            # When using num_iterations == 1, old_per_token_logps == per_token_logps, so we can skip it's
            # computation here, and use per_token_logps.detach() instead.
            if self.num_iterations > 1:
-                old_per_token_logps = self._get_per_token_logps(
-                    self.model, prompt_completion_ids, attention_mask, logits_to_keep
-                )
+                if self.args.use_vllm_logprobs:
+                    old_per_token_logps = vllm_log_probs
+                else:
+                    old_per_token_logps = self._get_per_token_logps(
+                        self.model, prompt_completion_ids, attention_mask, logits_to_keep
+                    )
            else:
                old_per_token_logps = None

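The practical effect of `use_vllm_logprobs` is on the policy ratio in the GRPO objective: with the flag enabled, the old logprobs come from vLLM's sampler rather than from a forward pass of the training model. A self-contained sketch with toy numbers (not the trainer's code):

```python
import torch

# Per-token logprobs for one completion of four tokens (toy values)
per_token_logps = torch.tensor([[-0.9, -1.2, -0.3, -2.0]])      # current policy
old_per_token_logps = torch.tensor([[-1.0, -1.1, -0.4, -1.8]])  # with use_vllm_logprobs=True these come from vLLM

# Importance ratio that enters GRPO's clipped objective; swapping the source of the
# old logprobs changes this ratio (and hence the clipping behaviour) token by token
ratio = torch.exp(per_token_logps - old_per_token_logps)
print(ratio)
```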