[recipe] feat: Add InfiGUI-G1 recipe for MLLM GUI grounding (#3242)

### What does this PR do?

This PR introduces a new recipe, `infigui-g1`, for training Multimodal
Large Language Models (MLLMs) on GUI grounding tasks. The recipe
implements a reinforcement learning approach that significantly improves
a model's ability to understand and interact with graphical user
interfaces.

### Checklist Before Starting

- [x] Search for similar PRs. Paste at least one query link here:
https://github.com/search?q=repo%3Avolcengine%2Fverl+gui&type=pullrequests
- [x] Format the PR title as `[{modules}] {type}: {description}` (This
will be checked by the CI)
- `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`,
`trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`,
`ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`,
`env`, `tool`, `ckpt`, `doc`, `data`
- If this PR involves multiple modules, separate them with `,` like
`[megatron, fsdp, doc]`
  - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test`
- If this PR breaks any API (CLI arguments, config, function signature,
etc.), add `[BREAKING]` to the beginning of the title.
  - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching`

### Test

The effectiveness of this recipe has been validated through experiments.
Key results are as follows:
- The training curves for reward, validation accuracy, and exploration
success rate all show an upward trend.
- After 156 steps of training on the sample data, the 3B model achieves a
score of **41.2** on the `screenspot-pro` benchmark, a substantial
improvement over the base model's score of **18.2**.
<img width="345" height="291" alt="Screenshot 2025-08-27 172010"
src="https://github.com/user-attachments/assets/9ecd93d5-4f9b-4c40-831c-79a50fd197c4"
/>
<img width="347" height="292" alt="Screenshot 2025-08-27 171902"
src="https://github.com/user-attachments/assets/2e437c1f-9eb0-4106-a6c3-b22125026a79"
/>
<img width="346" height="293" alt="Screenshot 2025-08-27 171928"
src="https://github.com/user-attachments/assets/9c94515d-1501-40f4-979c-95e2f819dc62"
/>

### API and Usage Example

The recipe is self-contained and can be run using the provided scripts.
For example, to run training with the 3B parameter model:

```bash
# In verl path
bash recipe/infigui-g1/run_3b.sh
```

### Design & Code Changes

This PR adds a new, independent recipe located in `recipe/infigui-g1/`.
The changes are fully encapsulated within this directory and do not
affect any other part of the codebase.

The new files include:
- `recipe/infigui-g1/README.md`: An introduction to the recipe.
- `recipe/infigui-g1/run_3b.sh`, `run_7b.sh`: Scripts to launch
training.
- `recipe/infigui-g1/reward_fn.py`: Custom reward function
implementation (a standalone usage sketch is shown below).
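
For reference, the reward function can be exercised outside the trainer. The sketch below is a hypothetical standalone check: it loads `reward_fn.py` by file path (the same file that `custom_reward_function.path` points at in the run scripts) and scores a toy response; the `<think>` answer, candidate points, and ground-truth box are made-up values, not taken from the dataset.

```python
# Hypothetical standalone check of the recipe's reward function (run from the verl root).
import importlib.util

spec = importlib.util.spec_from_file_location("reward_fn", "recipe/infigui-g1/reward_fn.py")
reward_fn = importlib.util.module_from_spec(spec)
spec.loader.exec_module(reward_fn)

# Toy model output: a single <think> block followed by a JSON list of candidate points.
solution = (
    "<think>The target button sits in the top-left toolbar.</think>"
    '[{"point_2d": [105, 42]}, {"point_2d": [300, 200]}]'
)
# The ground truth is a bounding box with x1/y1/x2/y2 keys, as read by _accuracy_reward.
ground_truth = {"x1": 100, "y1": 30, "x2": 120, "y2": 50}

result = reward_fn.aer_gui_reward_function("point", solution, ground_truth, extra_info={})
print(result["score"])  # format 1.0 + accuracy 1/sqrt(2*1) ≈ 1.707 (first of two candidates is correct)
```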

### Checklist Before Submitting

> [!IMPORTANT]
> Please check all the following items before requesting a review,
otherwise the reviewer might deprioritize this PR for review.

- [x] Read the [Contribute
Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md).
- [x] Apply [pre-commit
checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting):
`pre-commit install && pre-commit run --all-files --show-diff-on-failure
--color=always`
- [ ] Add / Update [the
documentation](https://github.com/volcengine/verl/tree/main/docs).
- [ ] Add unit or end-to-end test(s) to [the CI
workflow](https://github.com/volcengine/verl/tree/main/.github/workflows)
to cover all the code. If not feasible, explain why: ...
- [ ] Once your PR is ready for CI, send a message in [the `ci-request`
channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the
`verl` Slack
workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ).
(If not accessible, please try [the Feishu group
(飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).)
Commit `1e413344a2` (parent `53b68c638b`), authored by Yuhang Liu on 2025-08-27 23:35:22 +08:00 and committed by GitHub. 4 changed files, 554 additions, 0 deletions.

**File: `recipe/infigui-g1/README.md`** (new file, +56 lines)
# Recipe for InfiGUI-G1
This directory contains the official implementation for the paper [InfiGUI-G1: Advancing GUI Grounding with Adaptive Exploration Policy Optimization](https://arxiv.org/abs/2508.05731).
This work introduces Adaptive Exploration Policy Optimization (AEPO), a policy optimization framework designed to enhance GUI grounding in Multimodal Large Language Models (MLLMs). AEPO improves exploration efficiency by employing a multi-answer generation strategy and a theoretically grounded Adaptive Exploration Reward (AER) function. This approach effectively addresses the challenge of semantic alignment in complex GUI grounding tasks.
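To make the reward concrete, the snippet below mirrors the accuracy term of AER as implemented in `reward_fn.py` (a compact sketch, not the full reward, which also adds a format term): a response that proposes `n` candidate points and first hits the target at rank `k` earns `1/sqrt(n*k)`, while a response with no hit earns `-1/n`.
```python
import math

def aer_accuracy(num_pred: int, first_correct_rank: int | None) -> float:
    """Accuracy term of the Adaptive Exploration Reward (mirrors _accuracy_reward in reward_fn.py)."""
    if first_correct_rank is None:  # no predicted point falls inside the ground-truth box
        return -1.0 / num_pred
    return 1.0 / math.sqrt(num_pred * first_correct_rank)

print(aer_accuracy(1, 1))     # 1.00  -> one confident, correct point gets the full reward
print(aer_accuracy(4, 1))     # 0.50  -> correct on the first of four candidates
print(aer_accuracy(4, 4))     # 0.25  -> correct only on the last of four candidates
print(aer_accuracy(4, None))  # -0.25 -> four candidates, none correct
```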
We provide training scripts for both 3B and 7B models, configured for a single machine with 8 GPUs by default.
## Environment Setup
Please follow the main environment setup guide for `verl`.
The provided scripts use the following Docker image: `verlai/verl:app-verl0.5-transformers4.55.4-sglang0.4.10.post2-mcore0.13.0-te2.2`
## Data Preparation
Before starting training, download the example dataset. It is a filtered version of [omniact](https://huggingface.co/datasets/Writer/omniact) that contains only grounding tasks and excludes easy samples.
The data is hosted on the Hugging Face Hub and can be downloaded with `huggingface-cli`:
```bash
huggingface-cli download --repo-type dataset --resume-download InfiX-ai/omniact_grounding_filtered --local-dir data/omniact_grounding_filtered
```
This command will download the training and validation parquet files into the `data/omniact_grounding_filtered` directory, which is the default path used by the scripts.
## Training
We provide scripts to train the 3B and 7B models. Please run them from the root directory of `verl`.
- **Train the 3B model:**
  ```bash
  bash recipe/infigui-g1/run_3b.sh
  ```
- **Train the 7B model:**
  ```bash
  bash recipe/infigui-g1/run_7b.sh
  ```
## Using Custom Data
If you wish to train on your own dataset, please format your data to match the structure of the example files located in `data/omniact_grounding_filtered`.
Once your data is ready, you need to update the data path arguments in the training script.
In `run_3b.sh` or `run_7b.sh`, modify the following lines:
```bash
data.train_files=./path/to/your/train_data.parquet \
data.val_files=./path/to/your/val_data.parquet \
```
Replace the paths with the location of your custom data files.
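If you are unsure which fields the recipe expects, inspect one of the example parquet files and mirror its columns. The snippet below is a minimal sketch assuming `pandas` with a parquet engine (e.g. `pyarrow`) is installed; note that the image column must match `data.image_key=images` in the run scripts.
```python
import pandas as pd

# Load an example file downloaded by the huggingface-cli command above.
df = pd.read_parquet("data/omniact_grounding_filtered/omniact_filtered_train.parquet")

# Print the schema and one full row to copy the structure for your own dataset.
print(df.columns.tolist())
print(df.iloc[0].to_dict())
```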

**File: `recipe/infigui-g1/reward_fn.py`** (new file, +388 lines)
# Copyright 2025 Individual Contributor: InfiX.ai
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import logging
import math
import re
from itertools import combinations

FMT_RATIO = 1.0
ACC_RATIO = 1.0


# ============================================================================
# Utility Functions
# ============================================================================
def extract_think_format(predict_str: str) -> None | dict[str, str]:
    """
    Check if the predicted string meets format requirements and extract thinking and answer parts.

    Args:
        predict_str: The predicted string

    Returns:
        If format requirements are met, returns a dictionary containing thinking and answer parts;
        otherwise returns None
    """
    if not predict_str or not isinstance(predict_str, str):
        return None

    # Check if <think> is at the beginning
    if not predict_str.startswith("<think>"):
        return None

    # Check if there is <think>...</think> format
    pattern = r"<think>(.*?)</think>"
    think_match = re.search(pattern, predict_str, re.DOTALL)
    if not think_match:
        return None

    if predict_str.count("<think>") != 1 or predict_str.count("</think>") != 1:
        return None

    # Extract thinking content
    think_content = think_match.group(1).strip()
    if not think_content:
        return None

    # Get content after </think>
    think_end_pos = predict_str.find("</think>") + len("</think>")
    post_think_content = predict_str[think_end_pos:].strip()

    # Check if there is non-empty content after </think>
    if not post_think_content:
        return None

    return {"think": think_content, "answer": post_think_content}


def extract_and_parse_json(input_string, wrapper):
    """
    Try to extract and parse JSON from a string.

    Args:
        input_string: The input string
        wrapper: JSON wrapper symbols, can be '{}' or '[]'

    Returns:
        Parsed JSON object, returns None if parsing fails
    """
    if len(wrapper) != 2:
        raise ValueError("Wrapper must be exactly two characters long")

    start_char, end_char = wrapper
    start_index = input_string.find(start_char)
    if start_index == -1:
        return None

    # Find the matching end character by balancing brackets/braces
    balance = 1
    end_index = -1
    for i in range(start_index + 1, len(input_string)):
        if input_string[i] == start_char:
            balance += 1
        elif input_string[i] == end_char:
            balance -= 1
            if balance == 0:
                end_index = i
                break

    if end_index == -1:
        return None

    json_string = input_string[start_index : end_index + 1]
    try:
        return json.loads(json_string)
    except json.JSONDecodeError:
        return None


# ============================================================================
# AER Reward Functions
# ============================================================================
def _extract_verifiable_answer(answer):
    """
    Extract and verify the format of the point list from the answer string.

    A valid format is a JSON list of dictionaries, where each dictionary
    has a "point_2d" key with a list of two numbers as the value.

    Args:
        answer: The answer string to extract points from

    Returns:
        List of valid points or None if format is invalid
    """
    points = extract_and_parse_json(answer, "[]")
    if points is None or not isinstance(points, list):
        return None

    # Verify each point in the list
    for point in points:
        if isinstance(point, dict) and "point_2d" in point:
            point_2d = point["point_2d"]
            if isinstance(point_2d, list) and len(point_2d) == 2:
                continue
        # If any point is malformed, the whole answer is invalid
        return None

    return points


def _format_reward(answer):
    """
    Calculate the format reward for 'point' type data.

    This function is now primarily used as a check to see if the format is valid.

    Args:
        answer: The answer string to validate

    Returns:
        Tuple of (reward, is_collinear) where reward is 1.0 for valid format, 0.0 otherwise
    """
    points = _extract_verifiable_answer(answer)
    if points is None:
        return 0.0, 0

    points_2d = [item["point_2d"] for item in points]
    if _check_collinear(points_2d):
        return 0.0, 1

    return 1.0, 0


def _check_collinear(points_2d):
    """
    Check if 3 or more points in the list are collinear on any straight line.

    This uses the cross-product method to avoid division and handle all line types.

    Args:
        points_2d: A list of [x, y] coordinates

    Returns:
        True if 3 or more points are collinear, False otherwise
    """
    if len(points_2d) < 3:
        return False

    # Iterate through all unique combinations of 3 points
    for p1, p2, p3 in combinations(points_2d, 3):
        x1, y1 = p1
        x2, y2 = p2
        x3, y3 = p3
        # Check for collinearity using the cross-product method.
        # If (y2 - y1) * (x3 - x1) == (y3 - y1) * (x2 - x1), the points are collinear.
        # This is equivalent to checking if the area of the triangle formed by the points is 0.
        if math.isclose((y2 - y1) * (x3 - x1), (y3 - y1) * (x2 - x1)):
            return True

    return False


def _accuracy_reward(answer, ground_truth):
    """
    Calculate the accuracy reward based on the symmetric zero-centered formula.

    The reward is in the range [-1, 1].

    Args:
        answer: The answer string containing predicted points
        ground_truth: Ground truth bounding box dictionary

    Returns:
        Tuple containing:
            - accuracy (float): The calculated reward
            - extracted_answer (str): The JSON string of the predicted points
            - num_pred (int): The number of predicted points
            - first_correct (int): 1 if the first predicted point is correct, 0 otherwise
    """
    pred_points = _extract_verifiable_answer(answer)

    # If no valid points are extracted, this is considered a format error, return -1 reward
    if pred_points is None:
        return -1.0, "", 0, 0

    num_pred = len(pred_points)
    extracted_answer = json.dumps(pred_points)

    if num_pred == 0:
        return -1.0, extracted_answer, 0, 0

    # Find the rank 'k' of the first correct point
    first_correct_rank = -1
    for i, item in enumerate(pred_points):
        point_2d = item["point_2d"]
        if (
            ground_truth["x1"] <= point_2d[0] <= ground_truth["x2"]
            and ground_truth["y1"] <= point_2d[1] <= ground_truth["y2"]
        ):
            first_correct_rank = i + 1  # 1-based index
            break

    # Calculate reward based on the zero-centered symmetric formula
    accuracy = 0.0
    if first_correct_rank != -1:
        # Case a: Correct point found (Positive reward space)
        k = first_correct_rank
        accuracy = 1.0 / math.sqrt(num_pred * k)
    else:
        # Case b: No correct point found (Negative reward space)
        accuracy = -1.0 / num_pred

    first_correct = 1 if first_correct_rank == 1 else 0

    return accuracy, extracted_answer, num_pred, first_correct


def calculate_point_reward(solution_str, ground_truth, extra_info=None, fmt_ratio=1.0, acc_ratio=1.0, **kwargs):
    """
    Calculate the final reward for 'point' type data.

    Implements the full logic including format checks, collinearity checks,
    and the zero-centered symmetric reward calculation.

    Args:
        solution_str: The solution string from the model
        ground_truth: Ground truth data
        extra_info: Extra information dictionary
        fmt_ratio: Format reward ratio
        acc_ratio: Accuracy reward ratio
        **kwargs: Additional keyword arguments

    Returns:
        Dictionary containing detailed reward information
    """
    if extra_info is not None and extra_info.get("no_think", False):
        answer = solution_str
    else:
        solution_dict = extract_think_format(solution_str)
        # If the overall 'think'/'answer' format is wrong, return score of -1
        if solution_dict is None:
            return {
                "score": -1.0,
                "format": 0.0,
                "accuracy": -1.0,
                "pred": "",
                "num_pred": 0,
                "has_correct": 0,
                "first_correct": 0,
                "only_correct": 0,
                "is_collinear": 0,
            }
        answer = solution_dict["answer"]

    # Reuse _format_reward to check the format of the 'answer' part
    # If it's invalid, return score of -1
    format_reward, is_collinear = _format_reward(answer)
    if format_reward == 0.0:
        return {
            "score": -1.0,
            "format": 0.0,
            "accuracy": -1.0,
            "pred": "",
            "num_pred": 0,
            "has_correct": 0,
            "first_correct": 0,
            "only_correct": 0,
            "is_collinear": is_collinear,
        }

    # If format is OK, calculate the accuracy reward
    accuracy_reward, extracted_answer, num_pred, first_correct = _accuracy_reward(answer, ground_truth)

    return {
        "score": fmt_ratio * format_reward + acc_ratio * accuracy_reward,
        "format": format_reward,
        "accuracy": accuracy_reward,
        "pred": extracted_answer,
        "num_pred": num_pred,
        "has_correct": 1 if accuracy_reward > 0 else 0,
        "first_correct": first_correct,
        "only_correct": 1 if num_pred == 1 and accuracy_reward > 0 else 0,
        "is_collinear": 0,
    }


# ============================================================================
# AER Reward Handler Registry
# ============================================================================
# Dictionary to map data_source to the respective reward calculation function
AER_REWARD_HANDLERS = {
    "point": calculate_point_reward,
}


def aer_gui_reward_function(data_source, solution_str, ground_truth, extra_info=None, **kwargs):
    """
    Main reward function dispatcher for the Adaptive Exploration Reward (AER) system.

    Delegates reward calculation to specific functions based on the data_source using a dictionary lookup.

    Args:
        data_source: The source or type of the data (e.g., "point", "bbox")
        solution_str: The solution string generated by the model
        ground_truth: The ground truth data
        extra_info: Any extra information passed along (optional)
        **kwargs: Additional keyword arguments that might be passed from the PPO trainer config

    Returns:
        Dictionary containing detailed reward information with keys:
            - score: The final calculated reward score
            - format: Format validation score
            - accuracy: Accuracy score
            - pred: Extracted prediction string
            - num_pred: Number of predictions
            - has_correct: Whether any correct prediction exists
            - first_correct: Whether first prediction is correct
            - only_correct: Whether only one correct prediction exists
            - is_collinear: Whether points are collinear (for point type)
    """
    handler = AER_REWARD_HANDLERS.get(data_source, None)
    if handler:
        try:
            return handler(
                solution_str, ground_truth, extra_info=extra_info, fmt_ratio=FMT_RATIO, acc_ratio=ACC_RATIO, **kwargs
            )
        except Exception as e:
            logging.exception(
                f"Error executing reward handler for data_source '{data_source}': {e}",
            )
            return {
                "score": -1.0,
                "format": 0.0,
                "accuracy": -1.0,
                "pred": "",
                "num_pred": 0,
                "has_correct": 0,
                "first_correct": 0,
                "only_correct": 0,
                "is_collinear": 0,
            }  # Return a default penalty score on error
    else:
        raise ValueError(f"Unknown data_source: '{data_source}'. No specific reward handler defined.")

**File: `recipe/infigui-g1/run_3b.sh`** (new file, +55 lines)
#!/bin/bash
set -x
ulimit -n 65535
python3 -m verl.trainer.main_ppo \
algorithm.adv_estimator=rloo \
data.train_files=./data/omniact_grounding_filtered/omniact_filtered_train.parquet \
data.val_files=./data/omniact_grounding_filtered/omniact_filtered_val.parquet \
data.train_batch_size=128 \
data.max_prompt_length=7168 \
data.max_response_length=1024 \
data.filter_overlong_prompts=False \
data.truncation='error' \
data.image_key=images \
custom_reward_function.path=./recipe/infigui-g1/reward_fn.py \
custom_reward_function.name=aer_gui_reward_function \
actor_rollout_ref.model.path=Qwen/Qwen2.5-VL-3B-Instruct \
actor_rollout_ref.model.enable_activation_offload=True \
actor_rollout_ref.model.enable_gradient_checkpointing=True \
actor_rollout_ref.model.use_remove_padding=True \
actor_rollout_ref.actor.use_dynamic_bsz=False \
actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=1 \
actor_rollout_ref.actor.optim.lr=1e-6 \
actor_rollout_ref.actor.optim.lr_warmup_steps=0 \
actor_rollout_ref.actor.ppo_mini_batch_size=128 \
actor_rollout_ref.actor.clip_ratio_high=0.4 \
actor_rollout_ref.actor.use_kl_loss=False \
actor_rollout_ref.actor.kl_loss_coef=0.001 \
actor_rollout_ref.actor.kl_loss_type=low_var_kl \
actor_rollout_ref.actor.entropy_coeff=0 \
actor_rollout_ref.actor.fsdp_config.param_offload=False \
actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=False \
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4 \
actor_rollout_ref.rollout.tensor_model_parallel_size=1 \
actor_rollout_ref.rollout.name=sglang \
actor_rollout_ref.rollout.gpu_memory_utilization=0.7 \
actor_rollout_ref.rollout.max_num_batched_tokens=8192 \
actor_rollout_ref.rollout.enable_chunked_prefill=False \
actor_rollout_ref.rollout.enforce_eager=False \
actor_rollout_ref.rollout.free_cache_engine=True \
actor_rollout_ref.rollout.n=8 \
actor_rollout_ref.rollout.temperature=1.0 \
actor_rollout_ref.ref.log_prob_use_dynamic_bsz=False \
actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4 \
actor_rollout_ref.ref.fsdp_config.param_offload=True \
algorithm.use_kl_in_reward=False \
trainer.logger=['console','wandb'] \
trainer.project_name='infigui-g1' \
trainer.experiment_name='3b' \
trainer.n_gpus_per_node=8 \
trainer.nnodes=1 \
trainer.save_freq=16 \
trainer.test_freq=16 \
trainer.total_epochs=6

**File: `recipe/infigui-g1/run_7b.sh`** (new file, +55 lines)
#!/bin/bash
set -x
ulimit -n 65535
python3 -m verl.trainer.main_ppo \
algorithm.adv_estimator=rloo \
data.train_files=./data/omniact_grounding_filtered/omniact_filtered_train.parquet \
data.val_files=./data/omniact_grounding_filtered/omniact_filtered_val.parquet \
data.train_batch_size=128 \
data.max_prompt_length=7168 \
data.max_response_length=1024 \
data.filter_overlong_prompts=False \
data.truncation='error' \
data.image_key=images \
custom_reward_function.path=./recipe/infigui-g1/reward_fn.py \
custom_reward_function.name=aer_gui_reward_function \
actor_rollout_ref.model.path=Qwen/Qwen2.5-VL-7B-Instruct \
actor_rollout_ref.model.enable_activation_offload=True \
actor_rollout_ref.model.enable_gradient_checkpointing=True \
actor_rollout_ref.model.use_remove_padding=True \
actor_rollout_ref.actor.use_dynamic_bsz=False \
actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=1 \
actor_rollout_ref.actor.optim.lr=1e-6 \
actor_rollout_ref.actor.optim.lr_warmup_steps=0 \
actor_rollout_ref.actor.ppo_mini_batch_size=128 \
actor_rollout_ref.actor.clip_ratio_high=0.4 \
actor_rollout_ref.actor.use_kl_loss=False \
actor_rollout_ref.actor.kl_loss_coef=0.001 \
actor_rollout_ref.actor.kl_loss_type=low_var_kl \
actor_rollout_ref.actor.entropy_coeff=0 \
actor_rollout_ref.actor.fsdp_config.param_offload=False \
actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=False \
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4 \
actor_rollout_ref.rollout.tensor_model_parallel_size=1 \
actor_rollout_ref.rollout.name=sglang \
actor_rollout_ref.rollout.gpu_memory_utilization=0.7 \
actor_rollout_ref.rollout.max_num_batched_tokens=8192 \
actor_rollout_ref.rollout.enable_chunked_prefill=False \
actor_rollout_ref.rollout.enforce_eager=False \
actor_rollout_ref.rollout.free_cache_engine=True \
actor_rollout_ref.rollout.n=8 \
actor_rollout_ref.rollout.temperature=1.0 \
actor_rollout_ref.ref.log_prob_use_dynamic_bsz=False \
actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4 \
actor_rollout_ref.ref.fsdp_config.param_offload=True \
algorithm.use_kl_in_reward=False \
trainer.logger=['console','wandb'] \
trainer.project_name='infigui-g1' \
trainer.experiment_name='7b' \
trainer.n_gpus_per_node=8 \
trainer.nnodes=1 \
trainer.save_freq=16 \
trainer.test_freq=16 \
trainer.total_epochs=6