[recipe] feat: Add InfiGUI-G1 recipe for MLLM GUI grounding (#3242)
### What does this PR do?

This PR introduces a new recipe, `infigui-g1`, for training Multimodal Large Language Models (MLLMs) on GUI grounding tasks. The recipe implements a reinforcement learning approach that significantly improves the model's ability to understand and interact with graphical user interfaces.

### Checklist Before Starting

- [x] Search for similar PRs. Paste at least one query link here: https://github.com/search?q=repo%3Avolcengine%2Fverl+gui&type=pullrequests
- [x] Format the PR title as `[{modules}] {type}: {description}` (This will be checked by the CI)
  - `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data`
  - If this PR involves multiple modules, separate them with `,` like `[megatron, fsdp, doc]`
  - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test`
  - If this PR breaks any API (CLI arguments, config, function signature, etc.), add `[BREAKING]` to the beginning of the title.
  - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching`

### Test

The effectiveness of this recipe has been validated through experiments. Key results are as follows:

- The training curves for reward, validation accuracy, and exploration success rate all show an upward trend.
- After 156 steps of training on the sample data, the 3B model achieves a score of **41.2** on the `screenspot-pro` benchmark, a substantial improvement over the base model's score of **18.2**.

<img width="345" height="291" alt="Screenshot 2025-08-27 172010" src="https://github.com/user-attachments/assets/9ecd93d5-4f9b-4c40-831c-79a50fd197c4" />
<img width="347" height="292" alt="Screenshot 2025-08-27 171902" src="https://github.com/user-attachments/assets/2e437c1f-9eb0-4106-a6c3-b22125026a79" />
<img width="346" height="293" alt="Screenshot 2025-08-27 171928" src="https://github.com/user-attachments/assets/9c94515d-1501-40f4-979c-95e2f819dc62" />

### API and Usage Example

The recipe is self-contained and can be run using the provided scripts. For example, to run training with the 3B-parameter model:

```bash
# In verl path
bash recipe/infigui-g1/run_3b.sh
```

### Design & Code Changes

This PR adds a new, independent recipe located in `recipe/infigui-g1/`. The changes are fully encapsulated within this directory and do not affect any other part of the codebase. The new files include:

- `recipe/infigui-g1/README.md`: An introduction to the recipe.
- `recipe/infigui-g1/run_3b.sh`, `run_7b.sh`: Scripts to launch training.
- `recipe/infigui-g1/reward_fn.py`: Custom reward function implementation.

### Checklist Before Submitting

> [!IMPORTANT]
> Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review.

- [x] Read the [Contribute Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md).
- [x] Apply [pre-commit checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting): `pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always`
- [ ] Add / Update [the documentation](https://github.com/volcengine/verl/tree/main/docs).
- [ ] Add unit or end-to-end test(s) to [the CI workflow](https://github.com/volcengine/verl/tree/main/.github/workflows) to cover all the code. If not feasible, explain why: ...
- [ ] Once your PR is ready for CI, send a message in [the `ci-request` channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the `verl` Slack workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ). (If not accessible, please try [the Feishu group (飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).)
recipe/infigui-g1/README.md (new file, 56 lines)
@@ -0,0 +1,56 @@
# Recipe for InfiGUI-G1

This directory contains the official implementation for the paper [InfiGUI-G1: Advancing GUI Grounding with Adaptive Exploration Policy Optimization](https://arxiv.org/abs/2508.05731).

This work introduces Adaptive Exploration Policy Optimization (AEPO), a policy optimization framework designed to enhance GUI grounding in Multimodal Large Language Models (MLLMs). AEPO improves exploration efficiency by employing a multi-answer generation strategy and a theoretically grounded Adaptive Exploration Reward (AER) function. This approach effectively addresses the challenge of semantic alignment in complex GUI grounding tasks.
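For reference, the AER accuracy term implemented in `reward_fn.py` scores a response that predicts $N$ candidate points by the rank $k$ of the first point that falls inside the ground-truth bounding box:

$$
r_{\text{acc}} =
\begin{cases}
\dfrac{1}{\sqrt{N k}} & \text{if the first correct point has rank } k, \\
-\dfrac{1}{N} & \text{if no predicted point is correct,}
\end{cases}
$$

and the final score adds a format reward of $1$ on top of this term; malformed or collinear multi-point outputs receive a score of $-1$ outright.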
We provide training scripts for both 3B and 7B models, configured for a single machine with 8 GPUs by default.

## Environment Setup

Please follow the main environment setup guide for `verl`.

The provided scripts use the following Docker image: `verlai/verl:app-verl0.5-transformers4.55.4-sglang0.4.10.post2-mcore0.13.0-te2.2`
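If you prefer to work inside the container, a minimal launch sketch (the mounts, shm size, and other flags are illustrative; adapt them to your setup):

```bash
# Illustrative: mount the verl checkout and expose all GPUs to the container.
docker run --gpus all --shm-size=16g -it \
    -v "$PWD":/workspace/verl -w /workspace/verl \
    verlai/verl:app-verl0.5-transformers4.55.4-sglang0.4.10.post2-mcore0.13.0-te2.2 bash
```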
## Data Preparation

Before starting training, you need to download the example dataset. This dataset is a filtered version of [omniact](https://huggingface.co/datasets/Writer/omniact), containing only grounding tasks and excluding easy samples.

The data is hosted on the Hugging Face Hub. You can download it using the `huggingface-cli`:

```bash
huggingface-cli download --repo-type dataset --resume-download InfiX-ai/omniact_grounding_filtered --local-dir data/omniact_grounding_filtered
```

This command downloads the training and validation parquet files into the `data/omniact_grounding_filtered` directory, which is the default path used by the scripts.
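To verify the download, you can inspect the parquet files; a minimal sketch (assumes `pandas` and `pyarrow` are installed):

```python
# Quick sanity check on the downloaded dataset (illustrative).
import pandas as pd

df = pd.read_parquet("data/omniact_grounding_filtered/omniact_filtered_train.parquet")
print(len(df), "training rows")
# The run scripts set data.image_key=images, so an `images` column is expected.
print(df.columns.tolist())
```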
## Training

We provide scripts to train the 3B and 7B models. Please run them from the root directory of `verl`.

- **Train the 3B model:**

  ```bash
  bash recipe/infigui-g1/run_3b.sh
  ```

- **Train the 7B model:**

  ```bash
  bash recipe/infigui-g1/run_7b.sh
  ```
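Both scripts log to Weights & Biases (`trainer.logger=['console','wandb']`), so authenticate before launching, for example:

```bash
# One-time setup; alternatively export WANDB_API_KEY in your environment.
wandb login
```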
## Using Custom Data

If you wish to train on your own dataset, format your data to match the structure of the example files located in `data/omniact_grounding_filtered`.

Once your data is ready, update the data path arguments in the training script. In `run_3b.sh` or `run_7b.sh`, modify the following lines, replacing the paths with the location of your custom data files:

```bash
data.train_files=./path/to/your/train_data.parquet \
data.val_files=./path/to/your/val_data.parquet \
```
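The reward function also constrains what each sample must carry. A minimal sketch of the fields it consumes (copy the exact parquet schema from the provided example files; the values below are illustrative):

```python
# Fields read by reward_fn.py for each sample (illustrative values).
example_reward_inputs = {
    "data_source": "point",  # selects calculate_point_reward via AER_REWARD_HANDLERS
    "ground_truth": {"x1": 100, "y1": 200, "x2": 180, "y2": 240},  # target bbox
    "extra_info": {"no_think": False},  # False => responses must use <think>...</think>
}
```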
recipe/infigui-g1/reward_fn.py (new file, 388 lines)
@@ -0,0 +1,388 @@
# Copyright 2025 Individual Contributor: InfiX.ai
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import logging
import math
import re
from itertools import combinations

FMT_RATIO = 1.0
ACC_RATIO = 1.0


# ============================================================================
# Utility Functions
# ============================================================================


def extract_think_format(predict_str: str) -> None | dict[str, str]:
    """
    Check if the predicted string meets format requirements and extract thinking and answer parts.

    Args:
        predict_str: The predicted string

    Returns:
        If format requirements are met, returns a dictionary containing thinking and answer parts;
        otherwise returns None
    """
    if not predict_str or not isinstance(predict_str, str):
        return None

    # Check if <think> is at the beginning
    if not predict_str.startswith("<think>"):
        return None

    # Check if there is <think>...</think> format
    pattern = r"<think>(.*?)</think>"
    think_match = re.search(pattern, predict_str, re.DOTALL)
    if not think_match:
        return None

    if predict_str.count("<think>") != 1 or predict_str.count("</think>") != 1:
        return None

    # Extract thinking content
    think_content = think_match.group(1).strip()
    if not think_content:
        return None

    # Get content after </think>
    think_end_pos = predict_str.find("</think>") + len("</think>")
    post_think_content = predict_str[think_end_pos:].strip()

    # Check if there is non-empty content after </think>
    if not post_think_content:
        return None

    return {"think": think_content, "answer": post_think_content}


def extract_and_parse_json(input_string, wrapper):
    """
    Try to extract and parse JSON from a string.

    Args:
        input_string: The input string
        wrapper: JSON wrapper symbols, can be '{}' or '[]'

    Returns:
        Parsed JSON object, returns None if parsing fails
    """
    if len(wrapper) != 2:
        raise ValueError("Wrapper must be exactly two characters long")

    start_char, end_char = wrapper
    start_index = input_string.find(start_char)

    if start_index == -1:
        return None

    # Find the matching end character by balancing brackets/braces
    balance = 1
    end_index = -1
    for i in range(start_index + 1, len(input_string)):
        if input_string[i] == start_char:
            balance += 1
        elif input_string[i] == end_char:
            balance -= 1

        if balance == 0:
            end_index = i
            break

    if end_index == -1:
        return None

    json_string = input_string[start_index : end_index + 1]

    try:
        return json.loads(json_string)
    except json.JSONDecodeError:
        return None


# ============================================================================
# AER Reward Functions
# ============================================================================


def _extract_verifiable_answer(answer):
    """
    Extract and verify the format of the point list from the answer string.

    A valid format is a JSON list of dictionaries, where each dictionary
    has a "point_2d" key with a list of two numbers as the value.

    Args:
        answer: The answer string to extract points from

    Returns:
        List of valid points or None if format is invalid
    """
    points = extract_and_parse_json(answer, "[]")
    if points is None or not isinstance(points, list):
        return None

    # Verify each point in the list
    for point in points:
        if isinstance(point, dict) and "point_2d" in point:
            point_2d = point["point_2d"]
            if isinstance(point_2d, list) and len(point_2d) == 2:
                continue

        # If any point is malformed, the whole answer is invalid
        return None

    return points


def _format_reward(answer):
    """
    Calculate the format reward for 'point' type data.

    This function is now primarily used as a check to see if the format is valid.

    Args:
        answer: The answer string to validate

    Returns:
        Tuple of (reward, is_collinear) where reward is 1.0 for valid format, 0.0 otherwise
    """
    points = _extract_verifiable_answer(answer)
    if points is None:
        return 0.0, 0

    points_2d = [item["point_2d"] for item in points]
    if _check_collinear(points_2d):
        return 0.0, 1

    return 1.0, 0


def _check_collinear(points_2d):
    """
    Check if 3 or more points in the list are collinear on any straight line.

    This uses the cross-product method to avoid division and handle all line types.

    Args:
        points_2d: A list of [x, y] coordinates

    Returns:
        True if 3 or more points are collinear, False otherwise
    """
    if len(points_2d) < 3:
        return False

    # Iterate through all unique combinations of 3 points
    for p1, p2, p3 in combinations(points_2d, 3):
        x1, y1 = p1
        x2, y2 = p2
        x3, y3 = p3

        # Check for collinearity using the cross-product method.
        # If (y2 - y1) * (x3 - x1) == (y3 - y1) * (x2 - x1), the points are collinear.
        # This is equivalent to checking if the area of the triangle formed by the points is 0.
        if math.isclose((y2 - y1) * (x3 - x1), (y3 - y1) * (x2 - x1)):
            return True

    return False


def _accuracy_reward(answer, ground_truth):
    """
    Calculate the accuracy reward based on the symmetric zero-centered formula.

    The reward is in the range [-1, 1].

    Args:
        answer: The answer string containing predicted points
        ground_truth: Ground truth bounding box dictionary

    Returns:
        Tuple containing:
        - accuracy (float): The calculated reward
        - extracted_answer (str): The JSON string of the predicted points
        - num_pred (int): The number of predicted points
        - first_correct (int): 1 if the first predicted point is correct, 0 otherwise
    """
    pred_points = _extract_verifiable_answer(answer)

    # If no valid points are extracted, this is considered a format error, return -1 reward
    if pred_points is None:
        return -1.0, "", 0, 0

    num_pred = len(pred_points)
    extracted_answer = json.dumps(pred_points)

    if num_pred == 0:
        return -1.0, extracted_answer, 0, 0

    # Find the rank 'k' of the first correct point
    first_correct_rank = -1
    for i, item in enumerate(pred_points):
        point_2d = item["point_2d"]
        if (
            ground_truth["x1"] <= point_2d[0] <= ground_truth["x2"]
            and ground_truth["y1"] <= point_2d[1] <= ground_truth["y2"]
        ):
            first_correct_rank = i + 1  # 1-based index
            break

    # Calculate reward based on the zero-centered symmetric formula
    accuracy = 0.0
    if first_correct_rank != -1:
        # Case a: Correct point found (Positive reward space)
        k = first_correct_rank
        accuracy = 1.0 / math.sqrt(num_pred * k)
    else:
        # Case b: No correct point found (Negative reward space)
        accuracy = -1.0 / num_pred

    first_correct = 1 if first_correct_rank == 1 else 0

    return accuracy, extracted_answer, num_pred, first_correct


def calculate_point_reward(solution_str, ground_truth, extra_info=None, fmt_ratio=1.0, acc_ratio=1.0, **kwargs):
    """
    Calculate the final reward for 'point' type data.

    Implements the full logic including format checks, collinearity checks,
    and the zero-centered symmetric reward calculation.

    Args:
        solution_str: The solution string from the model
        ground_truth: Ground truth data
        extra_info: Extra information dictionary
        fmt_ratio: Format reward ratio
        acc_ratio: Accuracy reward ratio
        **kwargs: Additional keyword arguments

    Returns:
        Dictionary containing detailed reward information
    """
    # Guard against the default extra_info=None before calling .get()
    if extra_info is None:
        extra_info = {}

    if extra_info.get("no_think", False):
        answer = solution_str
    else:
        solution_dict = extract_think_format(solution_str)
        # If the overall 'think'/'answer' format is wrong, return score of -1
        if solution_dict is None:
            return {
                "score": -1.0,
                "format": 0.0,
                "accuracy": -1.0,
                "pred": "",
                "num_pred": 0,
                "has_correct": 0,
                "first_correct": 0,
                "only_correct": 0,
                "is_collinear": 0,
            }

        answer = solution_dict["answer"]

    # Reuse _format_reward to check the format of the 'answer' part
    # If it's invalid, return score of -1
    format_reward, is_collinear = _format_reward(answer)
    if format_reward == 0.0:
        return {
            "score": -1.0,
            "format": 0.0,
            "accuracy": -1.0,
            "pred": "",
            "num_pred": 0,
            "has_correct": 0,
            "first_correct": 0,
            "only_correct": 0,
            "is_collinear": is_collinear,
        }

    # If format is OK, calculate the accuracy reward
    accuracy_reward, extracted_answer, num_pred, first_correct = _accuracy_reward(answer, ground_truth)

    return {
        "score": fmt_ratio * format_reward + acc_ratio * accuracy_reward,
        "format": format_reward,
        "accuracy": accuracy_reward,
        "pred": extracted_answer,
        "num_pred": num_pred,
        "has_correct": 1 if accuracy_reward > 0 else 0,
        "first_correct": first_correct,
        "only_correct": 1 if num_pred == 1 and accuracy_reward > 0 else 0,
        "is_collinear": 0,
    }


# ============================================================================
# AER Reward Handler Registry
# ============================================================================

# Dictionary to map data_source to the respective reward calculation function
AER_REWARD_HANDLERS = {
    "point": calculate_point_reward,
}


def aer_gui_reward_function(data_source, solution_str, ground_truth, extra_info=None, **kwargs):
    """
    Main reward function dispatcher for the Adaptive Exploration Reward (AER) system.

    Delegates reward calculation to specific functions based on the data_source using a dictionary lookup.

    Args:
        data_source: The source or type of the data (e.g., "point", "bbox")
        solution_str: The solution string generated by the model
        ground_truth: The ground truth data
        extra_info: Any extra information passed along (optional)
        **kwargs: Additional keyword arguments that might be passed from the PPO trainer config

    Returns:
        Dictionary containing detailed reward information with keys:
        - score: The final calculated reward score
        - format: Format validation score
        - accuracy: Accuracy score
        - pred: Extracted prediction string
        - num_pred: Number of predictions
        - has_correct: Whether any correct prediction exists
        - first_correct: Whether first prediction is correct
        - only_correct: Whether only one correct prediction exists
        - is_collinear: Whether points are collinear (for point type)
    """
    handler = AER_REWARD_HANDLERS.get(data_source, None)

    if handler:
        try:
            return handler(
                solution_str, ground_truth, extra_info=extra_info, fmt_ratio=FMT_RATIO, acc_ratio=ACC_RATIO, **kwargs
            )
        except Exception as e:
            logging.exception(
                f"Error executing reward handler for data_source '{data_source}': {e}",
            )
            return {
                "score": -1.0,
                "format": 0.0,
                "accuracy": -1.0,
                "pred": "",
                "num_pred": 0,
                "has_correct": 0,
                "first_correct": 0,
                "only_correct": 0,
                "is_collinear": 0,
            }  # Return a default penalty score on error
    else:
        raise ValueError(f"Unknown data_source: '{data_source}'. No specific reward handler defined.")
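A quick way to sanity-check the reward function outside the trainer is to call the dispatcher directly; a minimal sketch with made-up coordinates:

```python
# Illustrative smoke test for aer_gui_reward_function
# (run from recipe/infigui-g1/ so reward_fn is importable).
from reward_fn import aer_gui_reward_function

solution = (
    "<think>The submit button sits in the lower-right corner.</think>"
    '[{"point_2d": [640, 410]}, {"point_2d": [32, 40]}]'
)
ground_truth = {"x1": 600, "y1": 390, "x2": 700, "y2": 430}

result = aer_gui_reward_function("point", solution, ground_truth, extra_info={})
# First point is inside the bbox (k=1) out of N=2 predictions, so
# accuracy = 1/sqrt(2*1) ~= 0.707 and score = 1.0 + 0.707 ~= 1.707.
print(result["score"], result["num_pred"], result["first_correct"])
```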
recipe/infigui-g1/run_3b.sh (new file, 55 lines)
@@ -0,0 +1,55 @@
#!/bin/bash
set -x
ulimit -n 65535

python3 -m verl.trainer.main_ppo \
    algorithm.adv_estimator=rloo \
    data.train_files=./data/omniact_grounding_filtered/omniact_filtered_train.parquet \
    data.val_files=./data/omniact_grounding_filtered/omniact_filtered_val.parquet \
    data.train_batch_size=128 \
    data.max_prompt_length=7168 \
    data.max_response_length=1024 \
    data.filter_overlong_prompts=False \
    data.truncation='error' \
    data.image_key=images \
    custom_reward_function.path=./recipe/infigui-g1/reward_fn.py \
    custom_reward_function.name=aer_gui_reward_function \
    actor_rollout_ref.model.path=Qwen/Qwen2.5-VL-3B-Instruct \
    actor_rollout_ref.model.enable_activation_offload=True \
    actor_rollout_ref.model.enable_gradient_checkpointing=True \
    actor_rollout_ref.model.use_remove_padding=True \
    actor_rollout_ref.actor.use_dynamic_bsz=False \
    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=1 \
    actor_rollout_ref.actor.optim.lr=1e-6 \
    actor_rollout_ref.actor.optim.lr_warmup_steps=0 \
    actor_rollout_ref.actor.ppo_mini_batch_size=128 \
    actor_rollout_ref.actor.clip_ratio_high=0.4 \
    actor_rollout_ref.actor.use_kl_loss=False \
    actor_rollout_ref.actor.kl_loss_coef=0.001 \
    actor_rollout_ref.actor.kl_loss_type=low_var_kl \
    actor_rollout_ref.actor.entropy_coeff=0 \
    actor_rollout_ref.actor.fsdp_config.param_offload=False \
    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
    actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=False \
    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4 \
    actor_rollout_ref.rollout.tensor_model_parallel_size=1 \
    actor_rollout_ref.rollout.name=sglang \
    actor_rollout_ref.rollout.gpu_memory_utilization=0.7 \
    actor_rollout_ref.rollout.max_num_batched_tokens=8192 \
    actor_rollout_ref.rollout.enable_chunked_prefill=False \
    actor_rollout_ref.rollout.enforce_eager=False \
    actor_rollout_ref.rollout.free_cache_engine=True \
    actor_rollout_ref.rollout.n=8 \
    actor_rollout_ref.rollout.temperature=1.0 \
    actor_rollout_ref.ref.log_prob_use_dynamic_bsz=False \
    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4 \
    actor_rollout_ref.ref.fsdp_config.param_offload=True \
    algorithm.use_kl_in_reward=False \
    trainer.logger=['console','wandb'] \
    trainer.project_name='infigui-g1' \
    trainer.experiment_name='3b' \
    trainer.n_gpus_per_node=8 \
    trainer.nnodes=1 \
    trainer.save_freq=16 \
    trainer.test_freq=16 \
    trainer.total_epochs=6
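The 7B script below is identical except for the model path (`Qwen/Qwen2.5-VL-7B-Instruct`) and the experiment name, which you can confirm with:

```bash
diff recipe/infigui-g1/run_3b.sh recipe/infigui-g1/run_7b.sh
```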
recipe/infigui-g1/run_7b.sh (new file, 55 lines)
@@ -0,0 +1,55 @@
#!/bin/bash
set -x
ulimit -n 65535

python3 -m verl.trainer.main_ppo \
    algorithm.adv_estimator=rloo \
    data.train_files=./data/omniact_grounding_filtered/omniact_filtered_train.parquet \
    data.val_files=./data/omniact_grounding_filtered/omniact_filtered_val.parquet \
    data.train_batch_size=128 \
    data.max_prompt_length=7168 \
    data.max_response_length=1024 \
    data.filter_overlong_prompts=False \
    data.truncation='error' \
    data.image_key=images \
    custom_reward_function.path=./recipe/infigui-g1/reward_fn.py \
    custom_reward_function.name=aer_gui_reward_function \
    actor_rollout_ref.model.path=Qwen/Qwen2.5-VL-7B-Instruct \
    actor_rollout_ref.model.enable_activation_offload=True \
    actor_rollout_ref.model.enable_gradient_checkpointing=True \
    actor_rollout_ref.model.use_remove_padding=True \
    actor_rollout_ref.actor.use_dynamic_bsz=False \
    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=1 \
    actor_rollout_ref.actor.optim.lr=1e-6 \
    actor_rollout_ref.actor.optim.lr_warmup_steps=0 \
    actor_rollout_ref.actor.ppo_mini_batch_size=128 \
    actor_rollout_ref.actor.clip_ratio_high=0.4 \
    actor_rollout_ref.actor.use_kl_loss=False \
    actor_rollout_ref.actor.kl_loss_coef=0.001 \
    actor_rollout_ref.actor.kl_loss_type=low_var_kl \
    actor_rollout_ref.actor.entropy_coeff=0 \
    actor_rollout_ref.actor.fsdp_config.param_offload=False \
    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
    actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=False \
    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4 \
    actor_rollout_ref.rollout.tensor_model_parallel_size=1 \
    actor_rollout_ref.rollout.name=sglang \
    actor_rollout_ref.rollout.gpu_memory_utilization=0.7 \
    actor_rollout_ref.rollout.max_num_batched_tokens=8192 \
    actor_rollout_ref.rollout.enable_chunked_prefill=False \
    actor_rollout_ref.rollout.enforce_eager=False \
    actor_rollout_ref.rollout.free_cache_engine=True \
    actor_rollout_ref.rollout.n=8 \
    actor_rollout_ref.rollout.temperature=1.0 \
    actor_rollout_ref.ref.log_prob_use_dynamic_bsz=False \
    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4 \
    actor_rollout_ref.ref.fsdp_config.param_offload=True \
    algorithm.use_kl_in_reward=False \
    trainer.logger=['console','wandb'] \
    trainer.project_name='infigui-g1' \
    trainer.experiment_name='7b' \
    trainer.n_gpus_per_node=8 \
    trainer.nnodes=1 \
    trainer.save_freq=16 \
    trainer.test_freq=16 \
    trainer.total_epochs=6