mirror of https://github.com/volcengine/verl.git
synced 2025-10-20 13:43:50 +08:00

Compare commits
7 Commits: 33eb86f54f ... 061535208c

SHA1: 061535208c, 55f651c94d, 22d082f9a4, 8ec9bf64a1, 231d725f69, d69164e1cb, 2181d5b33a

1  .github/workflows/model.yml  (vendored)
@@ -208,6 +208,7 @@ jobs:
      - name: Running mcore engine tests on 8 L20 GPUs
        run: |
          ray stop --force
          pytest -s -x tests/models/test_engine.py

  cleanup:
@@ -6,21 +6,20 @@

This is the official implementation of the paper [***Geometric-Mean Policy Optimization***](https://arxiv.org/abs/2507.20673).

<div align=center>
<img width="3092" height="864" alt="image" src="https://github.com/user-attachments/assets/af4c7e0f-923a-45ef-9bcf-57109b8ee61e" />
<img width="3092" height="864" alt="image" src="https://github.com/user-attachments/assets/20b04c4e-7ee8-4775-9af8-33c0158336e2" />
</div>

## 1. Contents
- Geometric-Mean Policy Optimization
  - [1. Contents](#1-contents)
  - [2. Introduction](#2-introduction)
  - [3. Code Usage](#4-code-usage)
  - [4. Contacts](#5-contacts)
  - [5. Citation](#7-citation)
  - [3. Code Usage](#3-code-usage)
  - [4. Contacts](#4-contacts)
  - [5. Citation](#5-citation)

## 2. Introduction

Recent advancements, such as Group Relative Policy Optimization (GRPO), have enhanced the reasoning capabilities of large language models by optimizing the arithmetic mean of token-level rewards. However, GRPO suffers from unstable policy updates when processing tokens with outlier importance-weighted rewards, which manifests as extreme importance sampling ratios during training, i.e., the ratio between the sampling probabilities assigned to a token by the current and old policies. In this work, we propose Geometric-Mean Policy Optimization (GMPO), a stabilized variant of GRPO. Instead of optimizing the arithmetic mean, GMPO maximizes the geometric mean of token-level rewards, which is inherently less sensitive to outliers and maintains a more stable range of importance sampling ratio. In addition, we provide comprehensive theoretical and experimental analysis to justify the design and stability benefits of GMPO. Beyond improved stability, GMPO-7B outperforms GRPO by an average of 4.1% on multiple mathematical benchmarks and 1.4% on multimodal reasoning benchmark, including AIME24, AMC, MATH500, OlympiadBench, Minerva, and Geometry3K.
Group Relative Policy Optimization (GRPO) has significantly enhanced the reasoning capability of large language models by optimizing the arithmetic mean of token-level rewards. Unfortunately, GRPO is observed to suffer from unstable policy updates when facing tokens with outlier importance-weighted rewards, which manifest as extreme importance sampling ratios during training. In this study, we propose Geometric-Mean Policy Optimization (GMPO), with the aim to improve the stability of GRPO through suppressing token reward outliers. Instead of optimizing the arithmetic mean, GMPO maximizes the geometric mean of token-level rewards, which is inherently less sensitive to outliers and maintains a more stable range of importance sampling ratio. GMPO is plug-and-play—simply replacing GRPO's arithmetic mean with the geometric mean of token-level rewards, as the latter is inherently less sensitive to outliers. GMPO is theoretically plausible—analysis reveals that both GMPO and GRPO are weighted forms of the policy gradient while the former enjoys more stable weights, which consequently benefits policy optimization and performance. Experiments on multiple mathematical reasoning benchmarks show that GMPO-7B improves the average Pass@1 of GRPO by up to 4.1%, outperforming many state-of-the-art approaches.
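
The effect of swapping the aggregation can be seen in a small numerical sketch. The snippet below is illustrative only (it is not the `loss_mode=geo_mean` implementation in this repo, and it omits clipping and the sign handling for negative advantages): with a single outlier token, the arithmetic mean of importance-weighted rewards moves sharply, while the geometric mean, averaged in log space, stays in a narrow range.

```python
# Illustrative sketch only: arithmetic-mean (GRPO-style) vs geometric-mean
# (GMPO-style) aggregation of token-level importance ratios for one response.
# Clipping and the handling of negative advantages are omitted for brevity.
import torch

log_ratio = torch.tensor([0.05, -0.10, 2.50, 0.02])  # one outlier token
advantage = 1.0

# GRPO-style: arithmetic mean of ratio_t * A; the outlier token dominates the update.
arith = (log_ratio.exp() * advantage).mean()

# GMPO-style: geometric mean of ratio_t times A, computed in log space,
# so the outlier is damped because log-ratios are averaged before exponentiation.
geo = log_ratio.mean().exp() * advantage

print(f"arithmetic-mean objective: {arith.item():.3f}")  # ~3.8
print(f"geometric-mean objective:  {geo.item():.3f}")    # ~1.9
```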

## 3. Code Usage

@@ -30,7 +29,7 @@ clip_ratio_low=0.4
clip_ratio_high=0.4
loss_mode=geo_mean
```

We observed that using a large clip ratio during Mixture-of-Experts (MoE) model training often leads to optimization instability. When training MoE models, consider lowering the clip ratio to achieve more stable convergence.
To get started quickly, run:
```
bash examples/gmpo_trainer/run_qwen2_5-7b_math.sh

@@ -51,13 +50,10 @@ If you have any question about our work or this repository, please don't hesitat

## 5. Citation
```
@misc{zhao2025geometricmeanpolicyoptimization,
  title={Geometric-Mean Policy Optimization},
  author={Yuzhong Zhao and Yue Liu and Junpeng Liu and Jingye Chen and Xun Wu and Yaru Hao and Tengchao Lv and Shaohan Huang and Lei Cui and Qixiang Ye and Fang Wan and Furu Wei},
  year={2025},
  eprint={2507.20673},
  archivePrefix={arXiv},
  primaryClass={cs.CL},
  url={https://arxiv.org/abs/2507.20673},
@article{zhao2025geometric,
  title={Geometric-mean policy optimization},
  author={Zhao, Yuzhong and Liu, Yue and Liu, Junpeng and Chen, Jingye and Wu, Xun and Hao, Yaru and Lv, Tengchao and Huang, Shaohan and Cui, Lei and Ye, Qixiang and others},
  journal={arXiv preprint arXiv:2507.20673},
  year={2025}
}
```
@@ -132,7 +132,7 @@ class RayDAPOTrainer(RayPPOTrainer):
    batch_keys=["input_ids", "attention_mask", "position_ids"],
    non_tensor_batch_keys=["raw_prompt_ids"],
)
gen_batch = gen_batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=False)
gen_batch = gen_batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True)

is_last_step = self.global_steps >= self.total_training_steps

@@ -163,7 +163,7 @@ class RayDAPOTrainer(RayPPOTrainer):
    [str(uuid.uuid4()) for _ in range(len(new_batch.batch))], dtype=object
)
# repeat to align with repeated responses in rollout
new_batch = new_batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=False)
new_batch = new_batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True)
new_batch = new_batch.union(gen_batch_output)

with marked_timer("reward", timing_raw, "yellow"):
@@ -43,6 +43,20 @@ logger = logging.getLogger(__file__)
logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))


def format_tool_response_manually(tool_message: dict, tool_call_name: str) -> str:
    """Manually format tool response without using tokenizer template.

    Args:
        tool_message: Tool message dictionary with 'content' field
        tool_call_name: Name of the tool that was called

    Returns:
        Formatted tool response string
    """
    content = tool_message["content"]
    return f"<|start|>functions.{tool_call_name} to=assistant<|channel|>commentary<|message|>{content}<|end|>"


class MaxTokenExceededError(Exception):
    """Indicate that history chat messages + tool message exceeds LLM max_tokens."""

@@ -202,13 +216,39 @@ class ChatModel(BaseChatModel):

# encode tool response
tool_responses = convert_to_openai_messages(messages[i + 1 :])
tool_response_ids = await loop.run_in_executor(
    None,
    lambda messages=tool_responses: self.tokenizer.apply_chat_template(
        messages, add_generation_prompt=True, tokenize=True
    ),
)
tool_response_ids = tool_response_ids[len(kwargs["system_prompt"]) :]
if self.tool_parser == "hermes":
    tool_response_ids = await loop.run_in_executor(
        None,
        lambda messages=tool_responses: self.tokenizer.apply_chat_template(
            messages, add_generation_prompt=True, tokenize=True
        ),
    )
    tool_response_ids = tool_response_ids[len(kwargs["system_prompt"]) :]
elif self.tool_parser == "gpt-oss":
    # Format tool responses manually
    # since gpt-oss chat template requires tool call messages to parse tool response messages
    # we need to format the tool response messages manually
    tool_response_texts = []
    for tool_msg in tool_responses:
        if tool_msg["role"] == "tool":
            # Use tool message's name if available (for multiple tool calls)
            actual_tool_name = tool_msg.get("name", "unknown")
            if actual_tool_name == "unknown":
                logger.error(f"actual_tool_name: {actual_tool_name}")
            formatted = format_tool_response_manually(tool_msg, actual_tool_name)
            tool_response_texts.append(formatted)
    # need to add generation tokens for gpt-oss manually since add_generation_prompt is True
    tool_response_texts.append("<|start|>assistant")

    # Tokenize the manually formatted tool responses
    tool_response_text = "".join(tool_response_texts)
    print(f"tool_response_text: {tool_response_text}")

    tool_response_ids = await loop.run_in_executor(
        None, lambda: self.tokenizer.encode(tool_response_text, add_special_tokens=False)
    )
else:
    raise ValueError(f"Unsupported tool parser: {self.tool_parser}")

# stop generation if response length exceeds max response length
if len(messages[i].response_metadata["response_mask"]) + len(tool_response_ids) >= self.max_tokens:
143  recipe/langgraph_agent/example/run_gpt_oss_20b_bf16.sh  (Normal file)

@@ -0,0 +1,143 @@
#!/usr/bin/env bash
#SBATCH --job-name=rl-langgraph-3B
#SBATCH --partition=main
#SBATCH --nodes=1
#SBATCH --ntasks-per-node=1
#SBATCH --cpus-per-task=64
#SBATCH --gres=gpu:4
#SBATCH --mem=0
#SBATCH --time=10:00:00
#SBATCH --output=%x_%j.out
#SBATCH --error=%x_%j.err

set -xeuo pipefail

# ================= cluster topology =================
export GPUS_PER_NODE=${SLURM_GPUS_ON_NODE:-${GPUS_PER_NODE:-2}} # GPUs on this node
NNODES=${SLURM_JOB_NUM_NODES:-${NNODES:-1}}
export NNODES
export RAY_NUM_NODES=$NNODES

# Require at least 2 GPUs
TOTAL_GPUS=$((GPUS_PER_NODE * NNODES))
if [ "$TOTAL_GPUS" -lt 2 ]; then
    echo "Error: at least 2 GPUs are required, detected $TOTAL_GPUS." >&2
    exit 1
fi

echo "Using $NNODES nodes and $GPUS_PER_NODE GPUs per node..."

# ================= data/model/tool =================
HDFS_ROOT=${HDFS_ROOT:-$PWD}
DATA_ROOT=${DATA_ROOT:-$PWD}

# Prefer local model if present, otherwise fall back to HF hub path
model_path="lmsys/gpt-oss-20b-bf16"

# Use the default output directory produced by create_dataset.py
train_files=$DATA_ROOT/data/math_expression_tool/train.parquet
test_files=$DATA_ROOT/data/math_expression_tool/test.parquet

# Agent config
agent_loop_config_path=recipe/langgraph_agent/example/agent.yaml

# =================== wandb ===================
project_name=math_expression_tool
experiment_name=gpt-oss-20b-bf16
default_local_dir=$DATA_ROOT/checkpoint/$experiment_name

# ================= algorithm =================
adv_estimator=grpo

use_kl_in_reward=false
kl_coef=0.0
use_kl_loss=false
kl_loss_coef=0.0

clip_ratio_low=0.2
clip_ratio_high=0.28

max_turns=8
max_prompt_length=1024
max_response_length=8192
actor_lr=1e-6

train_batch_size=128
ppo_mini_batch_size=16
n_resp_per_prompt=8
n_resp_per_prompt_val=1

# =================== logging ===================
export RAY_LOGGING_LEVEL=DEBUG
export HYDRA_FULL_ERROR=1

# ================= performance =================
export NCCL_IBEXT_DISABLE=1
export NCCL_NVLS_ENABLE=1
export NCCL_IB_HCA=mlx5
export UCX_NET_DEVICES=mlx5_0:1,mlx5_1:1,mlx5_2:1,mlx5_3:1,mlx5_4:1,mlx5_5:1,mlx5_6:1,mlx5_7:1
export VLLM_USE_V1=1
export VLLM_ATTENTION_BACKEND=FLASH_ATTN

infer_tp=2 # vLLM tensor parallel size
train_sp=4 # Ulysses sequence parallel size for actor
offload=true

actor_max_token_len_per_gpu=$(( (max_prompt_length + max_response_length) * 4 ))
log_prob_max_token_len_per_gpu=$(( actor_max_token_len_per_gpu * 2 ))

train_files="['$train_files']"
test_files="['$test_files']"

python3 -m verl.trainer.main_ppo \
    algorithm.adv_estimator=$adv_estimator \
    algorithm.use_kl_in_reward=$use_kl_in_reward \
    algorithm.kl_ctrl.kl_coef=$kl_coef \
    data.train_files="$train_files" \
    data.val_files="$test_files" \
    data.return_raw_chat=true \
    data.train_batch_size=$train_batch_size \
    data.max_prompt_length=$max_prompt_length \
    data.max_response_length=$max_response_length \
    data.filter_overlong_prompts=true \
    data.truncation='error' \
    actor_rollout_ref.model.path="$model_path" \
    actor_rollout_ref.model.use_remove_padding=true \
    actor_rollout_ref.model.enable_gradient_checkpointing=true \
    actor_rollout_ref.actor.use_kl_loss=$use_kl_loss \
    actor_rollout_ref.actor.kl_loss_coef=$kl_loss_coef \
    actor_rollout_ref.actor.clip_ratio_low=$clip_ratio_low \
    actor_rollout_ref.actor.clip_ratio_high=$clip_ratio_high \
    actor_rollout_ref.actor.clip_ratio_c=10.0 \
    actor_rollout_ref.actor.optim.lr=$actor_lr \
    actor_rollout_ref.actor.use_dynamic_bsz=true \
    actor_rollout_ref.actor.ppo_mini_batch_size=$ppo_mini_batch_size \
    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=$actor_max_token_len_per_gpu \
    actor_rollout_ref.actor.ulysses_sequence_parallel_size=$train_sp \
    actor_rollout_ref.actor.fsdp_config.param_offload=$offload \
    actor_rollout_ref.actor.fsdp_config.optimizer_offload=$offload \
    actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=$log_prob_max_token_len_per_gpu \
    actor_rollout_ref.rollout.name=sglang \
    actor_rollout_ref.rollout.mode=async \
    actor_rollout_ref.rollout.tensor_model_parallel_size=$infer_tp \
    actor_rollout_ref.rollout.multi_turn.max_user_turns=$max_turns \
    actor_rollout_ref.rollout.multi_turn.max_assistant_turns=$max_turns \
    actor_rollout_ref.rollout.multi_turn.format=gpt-oss \
    actor_rollout_ref.rollout.agent.tool_parser=gpt-oss \
    actor_rollout_ref.rollout.agent.agent_loop_config_path=$agent_loop_config_path \
    actor_rollout_ref.rollout.gpu_memory_utilization=0.7 \
    actor_rollout_ref.rollout.n=$n_resp_per_prompt \
    actor_rollout_ref.rollout.val_kwargs.top_p=1.0 \
    actor_rollout_ref.rollout.val_kwargs.temperature=1.0 \
    actor_rollout_ref.rollout.val_kwargs.n=$n_resp_per_prompt_val \
    trainer.logger='["console","wandb"]' \
    trainer.project_name=$project_name \
    trainer.experiment_name=$experiment_name \
    trainer.n_gpus_per_node="$GPUS_PER_NODE" \
    trainer.val_before_train=true \
    trainer.log_val_generations=50 \
    trainer.nnodes="$NNODES" \
    trainer.save_freq=-1 \
    trainer.default_local_dir="$default_local_dir" \
    trainer.test_freq=5 \
    trainer.total_epochs=1 "$@"
55  recipe/open_math_reasoning/README.md  (Normal file)

@@ -0,0 +1,55 @@
# Open math reasoning
## Introduction
In this recipe, we perform SFT on the [open math reasoning](https://huggingface.co/datasets/nvidia/OpenMathReasoning) dataset using the new SFT trainer with a backend-agnostic model engine. Note that our goal is not to replicate the [AIMO-2 Winning Solution](https://arxiv.org/abs/2504.16891) work, but to demonstrate an end-to-end SFT workflow.

Note that you may need to adjust the paths in the following scripts.
## Dataset Preprocessing
### Download Dataset
```bash
hf download nvidia/OpenMathReasoning --repo-type dataset --include data/cot* --local-dir /path/to/dataset/nvidia/OpenMathReasoning
hf download math-ai/aime24 --repo-type dataset --local-dir /path/to/dataset/math-ai/aime24
hf download math-ai/aime25 --repo-type dataset --local-dir /path/to/dataset/math-ai/aime25
```

### Preprocess the dataset
```bash
python3 recipe/open_math_reasoning/prepare_nvidia-OpenMathReasoning_sft.py --local_dataset_path /path/to/nvidia/OpenMathReasoning --local_save_dir /path/to/open_math_reasoning
```
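
Each row of the resulting `cot_dataset.parquet` uses the multi-turn `messages` format consumed by the SFT trainer, with `loss_mask` marking which turns contribute to the loss (the field layout follows `prepare_nvidia-OpenMathReasoning_sft.py` below). A sketch of one preprocessed row, with placeholder strings:

```python
# Sketch of a single preprocessed row; the string values are placeholders.
row = {
    "messages": [
        {"role": "user", "content": "<problem statement>", "loss_mask": 0},
        {"role": "assistant", "content": "<generated chain-of-thought solution>", "loss_mask": 1},
    ],
    "extra_info": {},  # the remaining dataset columns are carried over here
}
```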

### Prepare the eval dataset
```bash
python3 recipe/open_math_reasoning/prepare_eval_dataset.py --local_dataset_path /path/to/dataset --local_save_dir /path/to/eval_dataset
```

## Train the model using SFT
### FSDP backend
```bash
export CKPT_HOME=/path/to/ckpt
export BACKEND=fsdp2
export MODEL_ID=Qwen/Qwen3-8B-Base
export TRAIN_FILES=/path/to/open_math_reasoning/cot_dataset.parquet
bash recipe/open_math_reasoning/run_sft_qwen3_8b.sh
```

### Megatron backend
TODO

## Eval the model
### Merge checkpoint into huggingface format
```bash
python -m verl.model_merger merge --backend fsdp --local_dir /path/to/ckpt/global_step_19751 --target_dir /path/to/ckpt/global_step_19751/huggingface
```

### Generate the responses
```bash
export MODEL_PATH=/path/to/ckpt/global_step_19751/huggingface
bash recipe/open_math_reasoning/run_generation.sh
```

### Evaluate the responses
```bash
bash recipe/open_math_reasoning/run_eval.sh
```

You should see results like:
```python
{'test_score/aime24': 0.584375, 'test_score/aime25': 0.43333333333333335}
```
22  recipe/open_math_reasoning/compute_score.py  (Normal file)

@@ -0,0 +1,22 @@
# Copyright 2025 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


def compute_score_data_source(data_source, response, ground_truth):
    from verl.utils.reward_score.math_reward import compute_score

    if data_source in ["aime24", "aime25"]:
        return compute_score(response, ground_truth)
    else:
        raise ValueError(f"Unknown data source: {data_source}")
96  recipe/open_math_reasoning/prepare_eval_dataset.py  (Normal file)

@@ -0,0 +1,96 @@
# Copyright 2025 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# prepare eval dataset including AIME'24, AIME'25

# hf download math-ai/aime24 --repo-type dataset --local-dir /opt/tiger/datasets/math-ai/aime24
# hf download math-ai/aime25 --repo-type dataset --local-dir /opt/tiger/datasets/math-ai/aime25

import os

import datasets

from verl.utils.reward_score.math_reward import remove_boxed

instruction_following = "Please reason step by step, and put your final answer within \\boxed{}."


def make_map_fn(data_source):
    def process_fn(example, idx):
        question_raw = example.pop("problem")

        question = question_raw + " " + instruction_following

        if "solution" not in example:
            example["solution"] = example["answer"]

        answer_raw = example.pop("solution")

        example.clear()

        try:
            solution = remove_boxed(answer_raw)
        except Exception:
            solution = answer_raw

        data = {
            "data_source": data_source,
            "prompt": [
                {
                    "role": "user",
                    "content": question,
                }
            ],
            "ability": "math",
            "reward_model": {"style": "rule", "ground_truth": solution},
            "extra_info": {
                "index": idx,
                "answer": answer_raw,
                "question": question_raw,
            },
        }
        return data

    return process_fn


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--local_dataset_path", default=None, help="The local path to the raw dataset, if it exists.")
    parser.add_argument(
        "--local_save_dir", default="~/data/math-ai", help="The save directory for the preprocessed dataset."
    )

    args = parser.parse_args()

    if args.local_dataset_path is not None:
        aime24_dataset_path = os.path.join(args.local_dataset_path, "math-ai/aime24")
        aime25_dataset_path = os.path.join(args.local_dataset_path, "math-ai/aime25")
    else:
        aime24_dataset_path = "math-ai/aime24"
        aime25_dataset_path = "math-ai/aime25"

    aime24_dataset = datasets.load_dataset(aime24_dataset_path, split="test")
    aime25_dataset = datasets.load_dataset(aime25_dataset_path, split="test")

    aime24_dataset = aime24_dataset.map(function=make_map_fn("aime24"), with_indices=True)
    aime25_dataset = aime25_dataset.map(function=make_map_fn("aime25"), with_indices=True)

    local_save_dir = os.path.expanduser(args.local_save_dir)
    os.makedirs(local_save_dir, exist_ok=True)

    aime24_dataset.to_parquet(os.path.join(local_save_dir, "aime24_test.parquet"))
    aime25_dataset.to_parquet(os.path.join(local_save_dir, "aime25_test.parquet"))
72  recipe/open_math_reasoning/prepare_nvidia-OpenMathReasoning_sft.py  (Normal file)

@@ -0,0 +1,72 @@
# Copyright 2025 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
huggingface-cli download nvidia/OpenMathReasoning --repo-type dataset --include data/cot* \
    --local-dir /path/to/nvidia/OpenMathReasoning
huggingface-cli download nvidia/OpenMathReasoning --repo-type dataset --include data/cot* \
    --local-dir /opt/tiger/nvidia/OpenMathReasoning
"""

import argparse
import os

import datasets

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--local_dataset_path", default=None, help="The local path to the raw dataset, if it exists.")
    parser.add_argument(
        "--local_save_dir",
        default="~/data/open_math_reasoning",
        help="The save directory for the preprocessed dataset.",
    )

    args = parser.parse_args()
    local_dataset_path = args.local_dataset_path

    data_source = "nvidia/OpenMathReasoning"

    if local_dataset_path is not None:
        dataset = datasets.load_dataset(local_dataset_path, split="cot")
    else:
        dataset = datasets.load_dataset(data_source, split="cot")

    def make_map_fn(split):
        def process_fn(example, idx):
            question = example.pop("problem")
            solution = example.pop("generated_solution")

            extra_info = {}
            for key, value in example.items():
                extra_info[key] = value
            example.clear()

            data = {
                "messages": [
                    {"role": "user", "content": question, "loss_mask": 0},
                    {"role": "assistant", "content": solution, "loss_mask": 1},
                ],
                "extra_info": extra_info,
            }
            return data

        return process_fn

    # filter out data where the problem_type is not has_answer_extracted
    dataset = dataset.filter(lambda example: example["problem_type"] == "has_answer_extracted")
    dataset = dataset.map(function=make_map_fn("cot"), with_indices=True)
    local_save_dir = os.path.expanduser(args.local_save_dir)
    os.makedirs(local_save_dir, exist_ok=True)
    dataset.to_parquet(os.path.join(local_save_dir, "cot_dataset.parquet"))
7  recipe/open_math_reasoning/run_eval.sh  (Normal file)

@@ -0,0 +1,7 @@
#!/usr/bin/env bash

# Evaluation
python3 -m verl.trainer.main_eval \
    data.path=$HOME/data/gen/qwen_8b_gen_test.parquet \
    custom_reward_function.path=recipe/open_math_reasoning/compute_score.py \
    custom_reward_function.name=compute_score_data_source
32  recipe/open_math_reasoning/run_generation.sh  (Normal file)

@@ -0,0 +1,32 @@
#!/usr/bin/env bash

MODEL_PATH=${MODEL_PATH:-/path/to/ckpt/global_step_19751/huggingface}

NGPUS_PER_NODE=${NGPUS_PER_NODE:-8}
NNODES=${NNODES:-1}
OUTPUT_PATH=${OUTPUT_PATH:-$HOME/data/gen/qwen_8b_gen_test.parquet}
GEN_TP=${GEN_TP:-1} # Tensor parallel size (default 1)

aime24_test_path=${HOME}/data/math-ai/aime24_test.parquet
aime25_test_path=${HOME}/data/math-ai/aime25_test.parquet
train_files="['$aime24_test_path', '$aime25_test_path']"

python3 -m verl.trainer.main_generation_server \
    trainer.nnodes="${NNODES}" \
    trainer.n_gpus_per_node="${NGPUS_PER_NODE}" \
    actor_rollout_ref.model.path="${MODEL_PATH}" \
    actor_rollout_ref.model.trust_remote_code=True \
    actor_rollout_ref.rollout.temperature=1.0 \
    actor_rollout_ref.rollout.top_p=0.7 \
    actor_rollout_ref.rollout.prompt_length=2048 \
    actor_rollout_ref.rollout.response_length=20480 \
    actor_rollout_ref.rollout.tensor_model_parallel_size="${GEN_TP}" \
    actor_rollout_ref.rollout.gpu_memory_utilization=0.9 \
    actor_rollout_ref.rollout.name=vllm \
    actor_rollout_ref.rollout.n=32 \
    data.train_files="$train_files" \
    data.prompt_key=prompt \
    +data.output_path="${OUTPUT_PATH}"
94  recipe/open_math_reasoning/run_sft_qwen3_8b.sh  (Normal file)

@@ -0,0 +1,94 @@
#!/usr/bin/env bash
set -xeuo pipefail

ENTRYPOINT=${ENTRYPOINT:-"-m verl.trainer.sft_trainer"}

TRAIN_FILES=${TRAIN_FILES:-/path/to/cot_dataset.parquet}

backend=${BACKEND:-fsdp}

project_name=verl_sft_test

RESUME_MODE=auto
MODEL_ID=${MODEL_ID:-Qwen/Qwen3-8B-Base}

SP_SIZE=${SP_SIZE:-8}
FSDP_SIZE=${FSDP_SIZE:-16}
FSDP_STRATEGY=${FSDP_STRATEGY:-"fsdp2"}

TP_SIZE=${TP_SIZE:-1}
PP_SIZE=${PP_SIZE:-1}
VPP_SIZE=${VPP_SIZE:-null}
CP_SIZE=${CP_SIZE:-1}

PAD_MODE=${PAD_MODE:-no_padding}

USE_REMOVE_PADDING=${USE_REMOVE_PADDING:-True}

FSDP_ENGINE_CONFIG="\
    engine=${backend} \
    optim=${backend} \
    optim.lr=2e-5 \
    optim.lr_warmup_steps_ratio=0.01 \
    optim.weight_decay=0.1 \
    optim.betas="[0.9,0.95]" \
    optim.clip_grad=1.0 \
    optim.min_lr_ratio=0.1 \
    optim.warmup_style=cosine \
    engine.ulysses_sequence_parallel_size=${SP_SIZE} \
    engine.strategy=${FSDP_STRATEGY} \
    engine.fsdp_size=${FSDP_SIZE}"


MEGATRON_ENGINE_CONFIG="\
    engine=${backend} \
    optim=${backend} \
    optim.lr=1e-5 \
    optim.lr_warmup_steps_ratio=0.2 \
    optim.weight_decay=0.1 \
    optim.betas="[0.9,0.95]" \
    optim.clip_grad=1.0 \
    optim.lr_warmup_init=0 \
    optim.lr_decay_style=cosine \
    optim.min_lr=1e-6 \
    engine.tensor_model_parallel_size=${TP_SIZE} \
    engine.pipeline_model_parallel_size=${PP_SIZE} \
    engine.virtual_pipeline_model_parallel_size=${VPP_SIZE} \
    engine.context_parallel_size=${CP_SIZE}"

if [ "$backend" = "fsdp" ]; then
    ENGINE_CONFIG="$FSDP_ENGINE_CONFIG"
    echo "Using fsdp engine"
    exp_name=nvidia-openmathreasoning-qwen3-8b-${backend}-${FSDP_STRATEGY}-sp${SP_SIZE}-fsdp-1008a1
else
    ENGINE_CONFIG="$MEGATRON_ENGINE_CONFIG"
    echo "Using megatron engine"
    exp_name=nvidia-openmathreasoning-${backend}-tp${TP_SIZE}-pp${PP_SIZE}-vpp${VPP_SIZE}-cp${CP_SIZE}-pad-${PAD_MODE}-use_remove_padding-${USE_REMOVE_PADDING}
fi

CKPT_HOME=${CKPT_HOME:-$HOME/open_verl/sft/${project_name}/${exp_name}}
mkdir -p "${CKPT_HOME}"

torchrun --standalone --nnodes=1 --nproc-per-node=${NUM_TRAINERS:-8} \
    ${ENTRYPOINT} \
    data.train_files="${TRAIN_FILES}" \
    data.train_batch_size=96 \
    data.max_length=32768 \
    data.pad_mode=${PAD_MODE} \
    data.truncation=error \
    data.use_dynamic_bsz=True \
    data.max_token_len_per_gpu=65536 \
    data.messages_key=messages \
    model.path=$MODEL_ID \
    model.use_remove_padding=${USE_REMOVE_PADDING} \
    ${ENGINE_CONFIG} \
    trainer.test_freq=-1 \
    trainer.save_freq=4000 \
    trainer.logger=['console','wandb'] \
    trainer.project_name="${project_name}" \
    trainer.experiment_name="${exp_name}" \
    trainer.total_epochs=1 \
    trainer.default_local_dir="${CKPT_HOME}" \
    trainer.resume_mode=${RESUME_MODE} \
    trainer.max_ckpt_to_keep=5 \
    checkpoint.save_contents=[model,optimizer,extra]
@@ -24,7 +24,7 @@ import ray
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from transformers import AutoModelForCausalLM, AutoModelForTokenClassification, Qwen3Config, Qwen3MoeConfig
from transformers import AutoConfig, AutoModelForCausalLM, AutoModelForTokenClassification, Qwen3Config, Qwen3MoeConfig

from verl import DataProto
from verl.single_controller.ray import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup

@@ -289,8 +289,9 @@ def _worker(rank: int, world_size: int, rendezvous_file: str, strategy: str, mod
    world_size=world_size,
)

ref_model_config = AutoConfig.from_pretrained(model_path)
with torch.device("meta"):
    ref_model = AutoModelForCausalLM.from_pretrained(model_path)
    ref_model = AutoModelForCausalLM.from_config(ref_model_config)

from verl.workers.engine import BaseEngine, EngineRegistry
@@ -18,7 +18,7 @@ data:
  max_token_len_per_gpu: 8192
  use_dynamic_bsz: True
  train_files: ~/data/gsm8k/train.parquet
  val_files: ~/data/gsm8k/test.parquet
  val_files: null
  # Multi-turn settings
  messages_key: messages  # Key for messages list in multi-turn mode
  tools_key: tools  # Key for tools list in multi-turn mode
@@ -31,7 +31,8 @@ from verl.utils.fs import copy_to_local


@ray.remote
def process_item(reward_fn, data_source, response_lst, reward_data):
def process_item(config, data_source, response_lst, reward_data):
    reward_fn = get_custom_reward_fn(config)
    ground_truth = reward_data["ground_truth"]
    score_lst = [reward_fn(data_source, r, ground_truth) for r in response_lst]
    return data_source, np.mean(score_lst)

@@ -53,11 +54,9 @@ def main(config):

    # evaluate test_score based on data source
    data_source_reward = defaultdict(list)
    compute_score = get_custom_reward_fn(config)

    # Create remote tasks
    remote_tasks = [
        process_item.remote(compute_score, data_sources[i], responses[i], reward_model_data[i]) for i in range(total)
        process_item.remote(config, data_sources[i], responses[i], reward_model_data[i]) for i in range(total)
    ]

    # Process results as they come in
@@ -17,6 +17,7 @@ Generate responses given a dataset of prompts

import os

import aiohttp
import hydra
import numpy as np
import ray

@@ -30,31 +31,12 @@ from pprint import pprint

import pandas as pd
from omegaconf import OmegaConf
from openai import AsyncOpenAI
from openai.types.chat import ChatCompletion

from verl.utils.hdfs_io import makedirs
from verl.workers.rollout.replica import get_rollout_replica_class


@hydra.main(config_path="config", config_name="ppo_trainer", version_base=None)
def main(config):
    run_generation(config)


def run_generation(config) -> None:
    if not ray.is_initialized():
        # this is for local ray cluster
        default_runtime_env = {"env_vars": {"TOKENIZERS_PARALLELISM": "true", "NCCL_DEBUG": "WARN", "VLLM_USE_V1": "1"}}
        ray_init_kwargs = config.ray_kwargs.get("ray_init", {})
        runtime_env_kwargs = ray_init_kwargs.get("runtime_env", {})
        runtime_env = OmegaConf.merge(default_runtime_env, runtime_env_kwargs)
        ray_init_kwargs = OmegaConf.create({**ray_init_kwargs, "runtime_env": runtime_env})
        print(f"ray init kwargs: {ray_init_kwargs}")
        ray.init(**OmegaConf.to_container(ray_init_kwargs))

    ray.get(main_task.remote(config))


async def start_server(config):
    tp_size = config.actor_rollout_ref.rollout.tensor_model_parallel_size
    num_replicas = (config.trainer.n_gpus_per_node * config.trainer.nnodes) // tp_size

@@ -81,23 +63,42 @@ async def start_server(config):
    return server_handles, server_addresses


async def generate_per_replica(server_address, model_path: str, n_samples: int, sampling_params: dict, chat_lst: list):
    # here we should sample n_samples for each chat_lst
    client = AsyncOpenAI(
        api_key="123-abc",
        base_url=f"http://{server_address}/v1",
    )
async def submit_request(server_address, **chat_complete_request):
    try:
        extra_headers = chat_complete_request.pop("extra_headers", {})
        timeout = aiohttp.ClientTimeout(total=None)
        session = aiohttp.ClientSession(timeout=timeout)
        async with session.post(
            url=f"http://{server_address}/v1/chat/completions",
            headers={"Authorization": "Bearer token-abc123", **extra_headers},
            json=chat_complete_request,
        ) as resp:
            data = await resp.json()
            return ChatCompletion(**data)
    finally:
        await session.close()

    tasks = [
        client.chat.completions.create(
            model=model_path,
            messages=messages,

async def generate_per_replica(server_address, model_path: str, n_samples: int, sampling_params: dict, chat_lst: list):
    # here we should sample n_samples for each chat_lst.
    # we use aiohttp to avoid hang in AsyncOpenAI when the number of requests is large.

    # client = AsyncOpenAI(
    #     api_key="123-abc",
    #     base_url=f"http://{server_address}/v1",
    # )

    chat_complete_request = [
        {
            "model": model_path,
            "messages": messages,
            **sampling_params,
        )
        }
        for messages in chat_lst
        for _ in range(n_samples)
    ]

    tasks = [submit_request(server_address, **req) for req in chat_complete_request]
    results = await asyncio.gather(*tasks)
    return results

@@ -118,8 +119,10 @@ async def generate(
    return results


@ray.remote(num_cpus=1)
def main_task(config):
@hydra.main(config_path="config", config_name="ppo_trainer", version_base=None)
def main(config):
    ray.init(runtime_env={"env_vars": {"TOKENIZERS_PARALLELISM": "true", "NCCL_DEBUG": "WARN", "VLLM_USE_V1": "1"}})

    pprint(OmegaConf.to_container(config, resolve=True))  # resolve=True will eval symbol values
    OmegaConf.resolve(config)

@@ -136,8 +139,21 @@ def main_task(config):
        "max_tokens": config.actor_rollout_ref.rollout.response_length,
    }

    from omegaconf import ListConfig

    train_files = config.data.train_files
    if not isinstance(train_files, list | ListConfig):
        train_files = [train_files]

    # read dataset. Note that the dataset should directly contain chat template format (e.g., a list of dictionary)
    dataset = pd.read_parquet(config.data.train_files)

    datasets = []
    for train_file in train_files:
        dataset = pd.read_parquet(train_file)
        datasets.append(dataset)

    # concat dataset
    dataset = pd.concat(datasets, axis=0, ignore_index=True)
    chat_lst = dataset[config.data.prompt_key].tolist()
    chat_lst = [chat.tolist() for chat in chat_lst]
    chat_numpy = np.array(chat_lst)

@@ -151,7 +167,6 @@ def main_task(config):
    )

    # reshape results into a numpy array

    import itertools

    results = list(itertools.chain.from_iterable(gen_results))

@@ -170,6 +185,7 @@ def main_task(config):
    # write to a new parquet
    output_dir = os.path.dirname(config.data.output_path)
    makedirs(output_dir, exist_ok=True)
    print(f"Saving results to {config.data.output_path}")
    dataset.to_parquet(config.data.output_path)
@@ -146,7 +146,10 @@ class SFTTrainer:
    config = self.config
    tokenizer = self.model_config.tokenizer
    train_dataset = create_sft_dataset(config.data.train_files, config.data, tokenizer)
    val_dataset = create_sft_dataset(config.data.val_files, config.data, tokenizer)
    if config.data.val_files:
        val_dataset = create_sft_dataset(config.data.val_files, config.data, tokenizer)
    else:
        val_dataset = None

    self.train_dataset, self.val_dataset = train_dataset, val_dataset

@@ -181,19 +184,22 @@ class SFTTrainer:
        pin_memory_device=device_name,
    )

    self.val_sampler = DistributedSampler(
        self.val_dataset, shuffle=False, num_replicas=dp_size, rank=dp_rank, drop_last=True
    )
    self.val_dataloader = StatefulDataLoader(
        dataset=self.val_dataset,
        batch_size=self.train_batch_size_per_dp,
        sampler=self.val_sampler,
        collate_fn=self.collate_fn,
        num_workers=8,
        pin_memory=True,
        drop_last=True,
        pin_memory_device=device_name,
    )
    if self.val_dataset:
        self.val_sampler = DistributedSampler(
            self.val_dataset, shuffle=False, num_replicas=dp_size, rank=dp_rank, drop_last=True
        )
        self.val_dataloader = StatefulDataLoader(
            dataset=self.val_dataset,
            batch_size=self.train_batch_size_per_dp,
            sampler=self.val_sampler,
            collate_fn=self.collate_fn,
            num_workers=8,
            pin_memory=True,
            drop_last=True,
            pin_memory_device=device_name,
        )
    else:
        self.val_dataloader = None

def fit(self):
    is_logging = self.engine.is_mp_src_rank_with_outputs() and self.engine.get_data_parallel_rank() == 0

@@ -242,6 +248,7 @@ class SFTTrainer:
    }

    train_time = 0
    total_tokens = 0
    for epoch in range(start_epoch, self.config.trainer.total_epochs):
        self.train_sampler.set_epoch(epoch=epoch)

@@ -302,6 +309,8 @@ class SFTTrainer:
    metrics["train/grad_norm"] = metrics.pop("grad_norm")
    metrics["train/lr"] = lr
    metrics["train/global_tokens"] = output_tensor.sum().item()
    total_tokens += metrics["train/global_tokens"]
    metrics["train/total_tokens(B)"] = total_tokens / 1e9
    # mfu
    delta_time = timer.last
    estimated_flops, promised_flops = self.flops_counter.estimate_flops(batch_seqlens, delta_time)

@@ -315,7 +324,7 @@ class SFTTrainer:
    is_save_step = global_step % self.save_freq == 0

    # early exit or validation step
    if is_last_step or (self.test_freq > 0 and is_valid_step):
    if is_last_step and self.val_dataloader is not None or (self.test_freq > 0 and is_valid_step):
        # Perform validation
        val_losses = []
        for val_data in self.val_dataloader:
@@ -182,7 +182,8 @@ def find_latest_ckpt_path(path, directory_format="global_step_{}"):

    tracker_file = get_checkpoint_tracker_filename(path)
    if not os.path.exists(tracker_file):
        print(f"Checkpoint tracker file does not exist: {tracker_file}")
        if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
            print(f"Checkpoint tracker file does not exist: {tracker_file}")
        return None

    with open(tracker_file, "rb") as f:
@@ -1 +1 @@
0.5.0.dev
0.7.0.dev