Compare commits


7 Commits

SHA1 Message Date
9bf8db4887 Merge branch 'main' into docs/unify-trl-lib-namespace 2025-11-04 17:58:49 -08:00
5dfb2db0c1 docs: Use Qwen3-0.6B model as requested by reviewer
Update all model references in use_model.md to use Qwen/Qwen3-0.6B
as specifically requested by qgallouedec.

Changes:
- Replace Qwen/Qwen2.5-0.5B with Qwen/Qwen3-0.6B in all 3 locations
- Simpler model reference consistent with reviewer's suggestion
2025-11-04 17:56:19 -08:00
c34de94903 docs: Use official Qwen model instead of trl-lib namespace
Address reviewer feedback by replacing trl-lib/Qwen2-0.5B-XPO with the
official Qwen/Qwen2.5-0.5B model in all use_model.md examples.

Changes:
- Replace model references in 3 locations to use Qwen organization model
- More consistent with rest of TRL documentation
- Less misleading than custom trl-lib namespace model
2025-11-04 17:54:37 -08:00
800a4d928a Merge branch 'main' into docs/unify-trl-lib-namespace 2025-11-04 15:35:48 -08:00
6f906d5087 Apply suggestion from @qgallouedec 2025-11-04 16:32:09 -07:00
91e540ce09 Merge branch 'main' into docs/unify-trl-lib-namespace 2025-11-03 10:09:02 -08:00
580c6bb951 docs: Unify model examples to use trl-lib namespace
Resolves #4385

- Replace edbeeching/gpt-neo-125M-imdb with trl-lib/Qwen2-0.5B-XPO in peft_integration.md
- Replace kashif/stack-llama-2 with trl-lib/Qwen2-0.5B-XPO in use_model.md (3 occurrences)
- All personal developer namespace models now use common trl-lib namespace
2025-11-02 13:48:26 -08:00
15 changed files with 361 additions and 838 deletions

View File

@ -25,7 +25,7 @@ jobs:
steps:
- name: Git checkout
uses: actions/checkout@v4
with: { ref: v0.25-release }
with: { ref: v0.24-release }
- name: Set up Python 3.12
uses: actions/setup-python@v5

View File

@ -31,4 +31,4 @@ keywords:
- pytorch
- transformers
license: Apache-2.0
version: "0.25"
version: "0.24"

View File

@ -1 +1 @@
0.26.0.dev0
0.25.0.dev0

View File

@ -52,14 +52,6 @@ Community tutorials are made by active members of the Hugging Face community who
| Visual QA | [`DPOTrainer`] | Fine-Tuning a Vision Language Model with TRL using MPO | [Sergio Paniego](https://huggingface.co/sergiopaniego) | [Link](https://huggingface.co/learn/cookbook/fine_tuning_vlm_mpo) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/cookbook/blob/main/notebooks/en/fine_tuning_vlm_mpo.ipynb) |
| Reinforcement Learning | [`GRPOTrainer`] | Post training a VLM for reasoning with GRPO using TRL | [Sergio Paniego](https://huggingface.co/sergiopaniego) | [Link](https://huggingface.co/learn/cookbook/fine_tuning_vlm_grpo_trl) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/cookbook/blob/main/notebooks/en/fine_tuning_vlm_grpo_trl.ipynb) |
## Speech Language Models
### Tutorials
| Task | Class | Description | Author | Tutorial |
| --- | --- | --- | --- | --- |
| Text-to-Speech | [`GRPOTrainer`] | Post training a Speech Language Model with GRPO using TRL | [Steven Zheng](https://huggingface.co/Steveeeeeeen) | [Link](https://huggingface.co/blog/Steveeeeeeen/llasa-grpo) |
## Contributing
If you have a tutorial that you would like to add to this list, please open a PR to add it. We will review it and merge it if it is relevant to the community.

View File

@ -11,7 +11,7 @@ In this guide, we'll focus on **how to integrate OpenEnv with TRL**, but feel
To use OpenEnv with TRL, install the framework:
```bash
pip install git+https://github.com/meta-pytorch/OpenEnv.git
pip install openenv-core
```
## Using `rollout_func` with OpenEnv environments
@ -65,33 +65,6 @@ By using OpenEnv in this loop, you can:
* Plug in custom simulators, web APIs, or evaluators as environments.
* Pass structured reward signals back into RL training seamlessly.
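For reference, here is a minimal, hedged sketch of the `rollout_func` contract this integration relies on (the field names follow the echo example shown later in this guide; treat it as an illustration rather than a formal API):
```python
from trl import GRPOConfig

def rollout_func(prompts: list[str], args: GRPOConfig, processing_class) -> dict[str, list]:
    # 1. Generate completions for `prompts`, e.g. by calling a vLLM server.
    # 2. Step the OpenEnv environment with each completion to collect rewards.
    # 3. Return token IDs and logprobs, plus any extra per-completion fields
    #    (such as env_reward) that reward functions later receive via **kwargs.
    return {
        "prompt_ids": [],      # list[list[int]], one entry per completion
        "completion_ids": [],  # list[list[int]], one entry per completion
        "logprobs": [],        # list[list[float]], one entry per completion
        "env_reward": [],      # extra field propagated to the reward function
    }
```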
## Running the Environments
You can run OpenEnv environments in three different ways:
1. **Local Docker container** *(recommended)*
To start a Docker container:
* Open the environment on the Hugging Face Hub.
* Click the **⋮ (three dots)** menu.
* Select **“Run locally.”**
* Copy and execute the provided command in your terminal.
Example:
```bash
docker run -d -p 8001:8001 registry.hf.space/openenv-echo-env:latest
```
![open_env_launch_docker](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/open_env_launch_docker.png)
2. **Local Python process**: Launch the environment directly using Uvicorn.
You can start the server manually as a local process. For more details about the available environments, refer to the [OpenEnv repository](https://github.com/meta-pytorch/OpenEnv/tree/main/src/envs).
```bash
python -m uvicorn envs.echo_env.server.app:app --host 0.0.0.0 --port 8001
```
3. **Hugging Face Spaces**: Connect to a hosted environment running on the Hugging Face Hub.
To find the connection URL, open the Space page, click the **⋮ (three dots)** menu, and select **“Embed this Space.”**
You can then use that URL to connect directly from your client.
Keep in mind that public Spaces may have rate limits or temporarily go offline if inactive.
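Whichever setup you choose, the client connection looks roughly the same. Below is a hedged sketch using the Echo environment; the import path and the Space URL are assumptions, so adapt them to the environment you actually run:
```python
# Connecting an OpenEnv client to each of the three setups described above.
from envs.echo_env import EchoEnv  # assumed import path from the OpenEnv repo

# 1. Let the client start the Docker container itself
client = EchoEnv.from_docker_image("echo-env:latest")

# 2. Connect to a container or local uvicorn process already running on port 8001
client = EchoEnv(base_url="http://0.0.0.0:8001")

# 3. Connect to a hosted Hugging Face Space (URL copied from "Embed this Space")
client = EchoEnv(base_url="https://openenv-echo-env.hf.space")  # hypothetical URL
```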
## A simple example
The [echo.py](https://github.com/huggingface/trl/blob/main/examples/scripts/openenv/echo.py) script demonstrates a minimal, end-to-end integration between TRL and OpenEnv. In this example, the Echo environment rewards completions based on their text length, encouraging the model to generate longer outputs. This pattern can be extended to any custom environment that provides structured feedback or task-based rewards:
@ -102,15 +75,6 @@ from trl import GRPOConfig, GRPOTrainer
# Create HTTP client for Echo Environment
client = EchoEnv.from_docker_image("echo-env:latest")
"""
Alternatively, you can start the environment manually with Docker and connect to it:
# Step 1: Start the Echo environment
docker run -d -p 8001:8001 registry.hf.space/openenv-echo-env:latest
# Step 2: Connect the client to the running container
client = EchoEnv(base_url="http://0.0.0.0:8001")
"""
def rollout_func(prompts, args, processing_class):
# 1. Generate completions via vLLM inference server (running on port 8000)
@ -187,21 +151,6 @@ CUDA_VISIBLE_DEVICES=0 trl vllm-serve --model Qwen/Qwen2.5-0.5B-Instruct --host
CUDA_VISIBLE_DEVICES=1 python examples/scripts/openenv/echo.py
```
Alternatively, you can manually start the Echo environment in a Docker container before running the training:
```bash
# Launch the Echo environment
docker run -d -p 8001:8001 registry.hf.space/openenv-echo-env:latest
```
Then, initialize the client using:
`client = EchoEnv(base_url="http://0.0.0.0:8001")`
instead of:
`client = EchoEnv.from_docker_image("echo-env:latest")`.
Below is the reward curve from training:
<iframe src="https://trl-lib-trackio.hf.space?project=openenv&metrics=train/rewards/reward_from_env/mean&runs=qgallouedec-1761202871&sidebar=hidden&navbar=hidden" style="width:600px; height:500px; border:0;"></iframe>
@ -403,33 +352,22 @@ trainer = GRPOTrainer(
trainer.train()
```
### Running the Advanced Example
### Running the Example
The example requires two GPUs:
```bash
# Terminal 1: Start vLLM inference server
CUDA_VISIBLE_DEVICES=0 trl vllm-serve --model Qwen/Qwen3-1.7B --host 0.0.0.0 --port 8000
CUDA_VISIBLE_DEVICES=0 trl vllm-serve --model Qwen/Qwen2.5-0.5B-Instruct --host 0.0.0.0 --port 8000
# Terminal 2: Run GRPO training with OpenEnv
CUDA_VISIBLE_DEVICES=1 python examples/scripts/openenv/wordle.py
```
Again, you can manually start the TextArena environment in a Docker container before running the training.
In this case, initialize the client with
`client = TextArenaEnv(base_url="http://0.0.0.0:8001")`
instead of
`client = TextArenaEnv.from_docker_image("registry.hf.space/burtenshaw-textarena:latest")`:
```bash
# Launch the TextArena environment
docker run -d -p 8001:8001 registry.hf.space/burtenshaw-textarena:latest
```
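Putting the two options together, only the client initialization changes; a short sketch (the `TextArenaEnv` import path is an assumption):
```python
from envs.textarena_env import TextArenaEnv  # assumed import path

# Option A: connect to an environment you already started (Docker or local server)
client = TextArenaEnv(base_url="http://0.0.0.0:8001")

# Option B: let the client pull and start the Docker image itself
client = TextArenaEnv.from_docker_image("registry.hf.space/burtenshaw-textarena:latest")
```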
### Results
The resulting model improves its performance on the game, both by reducing the number of repetitions and by increasing the number of correct guesses. However, the Qwen3-1.7B model we trained is not able to consistently win the game. The following reward curve shows the coverage of the model's guesses and the coverage of correct Y and G letters.
<iframe src="https://burtenshaw-wordle-grpo.hf.space?project=group-Qwen-Qwen3-17B&metrics=reward&runs=run-2025-10-26_09-39-49,run-2025-10-26_08-04-49&sidebar=hidden&navbar=hidden" style="width:1600px; height:500px; border:0;"></iframe>
<iframe src="https://burtenshaw-wordle-grpo.hf.space/?project=group-Qwen-Qwen3-17B&metrics=train/rewards/reward_coverage/mean&runs=run-2025-10-26_09-39-49&sidebar=hidden&navbar=hidden" style="width:600px; height:500px; border:0;"></iframe>
We experimented with larger models like `gpt-oss-20b` and found that they were able to consistently win the game. However, training a model of that size requires a lot of compute. Why not try this out yourself?

View File

@ -9,7 +9,7 @@ If you have fine-tuned a model fully, meaning without the use of PEFT you can si
```python
from transformers import AutoTokenizer, AutoModelForCausalLM
model_name_or_path = "kashif/stack-llama-2" #path/to/your/model/or/name/on/hub
model_name_or_path = "Qwen/Qwen3-0.6B" #path/to/your/model/or/name/on/hub
device = "cpu" # or "cuda" if you have a GPU
model = AutoModelForCausalLM.from_pretrained(model_name_or_path).to(device)
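# Hedged usage sketch (not part of the diff above): the hunk ends before the
# tokenizer is loaded, so this continues the snippet with the standard
# Transformers API; the prompt and generation settings are illustrative.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
inputs = tokenizer("This movie was really", return_tensors="pt").to(device)
outputs = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))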
@ -25,7 +25,7 @@ Alternatively you can also use the pipeline:
```python
from transformers import pipeline
model_name_or_path = "kashif/stack-llama-2" #path/to/your/model/or/name/on/hub
model_name_or_path = "Qwen/Qwen3-0.6B" #path/to/your/model/or/name/on/hub
pipe = pipeline("text-generation", model=model_name_or_path)
print(pipe("This movie was really")[0]["generated_text"])
```
@ -36,7 +36,7 @@ print(pipe("This movie was really")[0]["generated_text"])
from peft import PeftConfig, PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer
base_model_name = "kashif/stack-llama-2" #path/to/your/model/or/name/on/hub
base_model_name = "Qwen/Qwen3-0.6B" #path/to/your/model/or/name/on/hub
adapter_model_name = "path/to/my/adapter"
model = AutoModelForCausalLM.from_pretrained(base_model_name)
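# Hedged continuation sketch (the rest of this snippet is not visible in the
# diff): attach the trained adapter using the PeftModel import shown above,
# then load the matching tokenizer; merging is optional and LoRA-specific.
model = PeftModel.from_pretrained(model, adapter_model_name)
tokenizer = AutoTokenizer.from_pretrained(base_model_name)
model = model.merge_and_unload()  # optional: merge LoRA weights for plain inference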

View File

@ -12,38 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Simple script to run GRPO training with OpenEnv's Catch environment (OpenSpiel) and a vLLM server. The reward function
is based on the catch game where the agent tries to catch falling balls.
Setup:
```sh
uv pip install git+https://github.com/meta-pytorch/OpenEnv.git
```
Usage (2 GPUs required):
# Start the docker container for the Catch environment (recommended). Alternatively, you can run it locally or directly from a HF Space.
```sh
docker run -d -p 8001:8001 registry.hf.space/openenv-openspiel-env:latest
```
# Spin up vLLM server
```sh
CUDA_VISIBLE_DEVICES=0 trl vllm-serve --model Qwen/Qwen2.5-0.5B-Instruct --host 0.0.0.0 --port 8000
```
# Run training
```sh
CUDA_VISIBLE_DEVICES=1 python examples/scripts/openenv/catch.py
```
"""
# ruff: noqa: T201
import argparse
import os
import re
import subprocess
@ -59,79 +28,34 @@ from envs.openspiel_env.models import OpenSpielAction
from trl import GRPOConfig, GRPOTrainer, RichProgressCallback, apply_chat_template
def parse_args():
parser = argparse.ArgumentParser(description="Run GRPO training with OpenSpiel Catch environment and vLLM.")
"""
Simple script to run GRPO training with OpenEnv's Catch environment (OpenSpiel) and a vLLM server. The reward function
is based on the catch game where the agent tries to catch falling balls.
# --- Environment settings ---
parser.add_argument("--env-host", type=str, default="0.0.0.0", help="Host for the environment server.")
parser.add_argument("--env-port", type=int, default=8001, help="Port for the environment server.")
parser.add_argument(
"--env-mode",
choices=["local", "docker", "space"],
default="docker",
help="Where to run the environment: 'local', 'docker', or 'space'.",
)
# --- Generation and model config ---
parser.add_argument(
"--gen-url",
type=str,
default="http://0.0.0.0:8000/generate/",
help="vLLM generation endpoint URL.",
)
parser.add_argument(
"--model",
type=str,
default="Qwen/Qwen2.5-0.5B-Instruct",
help="Model name or path.",
)
parser.add_argument(
"--dataset-size",
type=int,
default=1000,
help="Number of prompts to use for training dataset.",
)
Setup:
return parser.parse_args()
```sh
uv pip install git+https://github.com/meta-pytorch/OpenEnv.git
uv pip install open_spiel rich trackio
```
Usage (2 GPUs required):
def start_env_server(env_host: str, env_port: int):
"""Launch the OpenSpiel Catch environment locally via uvicorn."""
env_url = f"http://{env_host}:{env_port}"
print(f"⚡ Starting FastAPI server for OpenSpiel Catch Environment on {env_url}...")
# Spin up vLLM server
work_dir = str(Path.cwd().parent.absolute())
process = subprocess.Popen(
[
sys.executable,
"-m",
"uvicorn",
"envs.openspiel_env.server.app:app",
"--host",
env_host,
"--port",
str(env_port),
],
env={**os.environ, "PYTHONPATH": f"{work_dir}/src"},
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
text=True,
cwd=work_dir,
)
```sh
CUDA_VISIBLE_DEVICES=0 trl vllm-serve --model Qwen/Qwen2.5-0.5B-Instruct --host 0.0.0.0 --port 8000
```
print("⏳ Waiting for server to start...")
time.sleep(5)
# Run training
try:
requests.get(f"{env_url}/health", timeout=2)
print("\n✅ OpenSpiel Catch Environment server is running!")
except Exception as e:
print(f"\n❌ Server failed to start: {e}")
if process.stderr:
print(process.stderr.read())
raise
return process
```sh
CUDA_VISIBLE_DEVICES=1 python examples/scripts/openenv/catch.py
```
"""
GEN_URL = "http://0.0.0.0:8000/generate/"
ENV_URL = "http://0.0.0.0:8001"
BASE_PROMPT = """You are an AI agent playing the game **Catch**.
@ -144,64 +68,135 @@ BASE_PROMPT = """You are an AI agent playing the game **Catch**.
- You get **1 reward** if you miss it.
### Observation Format
Each observation is a flattened 10x5 grid (list of 50 floats).
- 1.0 → occupied (ball or paddle)
- 0.0 → empty cell
### Actions:
- `0` → Move left
- `1` → Stay
- `2` → Move right
- `observation`: a list of **50 numbers (floats)** representing the entire grid, flattened row by row.
- Each cell contains `1.0` if it is occupied (either by the ball or the paddle), or `0.0` if it is empty.
- The positions of the two `1.0` values indicate where the **ball** and **paddle** currently are.
- `legal_actions`: a list of integers representing which actions are currently allowed.
Respond **only** with one integer: `0`, `1`, or `2`.
### Actions Each action is a discrete integer:
- `0` → Move paddle **left**
- `1` → **Stay** (no movement)
- `2` → Move paddle **right**
### Output Format Respond **only with one integer** representing your chosen action: `0`, `1`, or `2`.
### Current Observation
"""
# Start the OpenSpiel server in background
print("⚡ Starting FastAPI server for OpenSpiel Catch Environment...")
def rollout_func(
prompts: list[str], args: GRPOConfig, processing_class, client: OpenSpielEnv, gen_url: str
) -> dict[str, list]:
"""Generate completions via vLLM and compute environment rewards."""
# Determine the correct path
work_dir = str(Path.cwd().parent.absolute())
server_process = subprocess.Popen(
[sys.executable, "-m", "uvicorn", "envs.openspiel_env.server.app:app", "--host", "0.0.0.0", "--port", "8001"],
env={**os.environ, "PYTHONPATH": f"{work_dir}/src"},
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
text=True,
cwd=work_dir,
)
print("⏳ Waiting for server to start...")
time.sleep(5)
# Check if server is running
try:
response = requests.get(f"{ENV_URL}/health", timeout=2)
print("\n✅ OpenSpiel Catch Environment server is running!")
except Exception as e:
print(f"\n❌ Server failed to start: {e}")
print("\n📋 Checking error output...")
server_process.poll()
if server_process.stderr:
stderr = server_process.stderr.read()
if stderr:
print(stderr)
raise
# Create HTTP client for OpenSpiel Catch Environment
client = OpenSpielEnv(base_url=f"{ENV_URL}")
def rollout_func(prompts: list[str], args: GRPOConfig, processing_class) -> dict[str, list]:
"""
Custom rollout function that generates completions via vLLM server and computes environment rewards.
The catch game expects action IDs (integers). We'll parse the model's text output to extract action choices.
Args:
prompts: List of prompts to generate from
args: GRPOConfig containing all sampling parameters
processing_class: Tokenizer/processor for decoding completions
Returns:
Dict containing prompt_ids, completion_ids, logprobs, and env_reward
"""
# Run full episodes for each generation to get episode rewards
env_rewards = []
all_prompt_ids, all_completion_ids, all_logprobs = [], [], []
all_prompt_ids = []
all_completion_ids = []
all_logprobs = []
for base_prompt in prompts:
for _ in range(args.num_generations):
# Run episode: Reset environment and loop until done
env_result = client.reset()
obs = env_result.observation
total_reward = 0.0
episode_prompt_ids, episode_completion_ids, episode_logprobs = [], [], []
episode_prompt_ids = []
episode_completion_ids = []
episode_logprobs = []
# TODO: parallelise!
while not obs.done:
# FIXME: handle the addition of observation to prompt more cleanly, ideally without a train_dataset
episode_msg = {"prompt": [{"role": "user", "content": f"{base_prompt}\n\n{obs.info_state}\n"}]}
episode_prompt = apply_chat_template(episode_msg, processing_class)
payload = {
# Generate action from model
gen_payload = {
"prompts": [episode_prompt["prompt"]],
"n": 1,
"temperature": args.temperature,
"top_p": args.top_p,
"top_k": -1 if args.top_k is None else args.top_k,
"min_p": 0.0 if args.min_p is None else args.min_p,
"max_tokens": args.max_completion_length,
"repetition_penalty": args.repetition_penalty,
}
response = requests.post(gen_url, json=payload)
response.raise_for_status()
result = response.json()
gen_response = requests.post(GEN_URL, json=gen_payload)
gen_response.raise_for_status()
gen_result = gen_response.json()
episode_prompt_ids.extend(result["prompt_ids"][0])
episode_completion_ids.extend(result["completion_ids"][0])
episode_logprobs.extend(result["logprobs"][0])
# Collect prompt_ids, completion_ids, and logprobs from this step
episode_prompt_ids.extend(gen_result["prompt_ids"][0])
episode_completion_ids.extend(gen_result["completion_ids"][0])
episode_logprobs.extend(gen_result["logprobs"][0])
completion_text = processing_class.batch_decode(result["completion_ids"], skip_special_tokens=True)[0]
completion_text = processing_class.batch_decode(
gen_result["completion_ids"], skip_special_tokens=True
)[0]
# Parse action from completion
action_id = 0 # default
numbers = re.findall(r"\b([0-2])\b", completion_text)
action_id = int(numbers[0]) if numbers else obs.legal_actions[0]
if numbers:
action_id = int(numbers[0])
elif obs.legal_actions:
action_id = obs.legal_actions[0]
# Take action in environment
env_result = client.step(OpenSpielAction(action_id=action_id, game_name="catch"))
total_reward += env_result.reward or 0.0
reward = env_result.reward if env_result.reward is not None else 0.0
total_reward += reward
obs = env_result.observation
# Store episode results
env_rewards.append(total_reward)
all_prompt_ids.append(episode_prompt_ids)
all_completion_ids.append(episode_completion_ids)
@ -215,60 +210,42 @@ def rollout_func(
}
dataset = Dataset.from_dict({"prompt": [BASE_PROMPT] * 1000})
def reward_from_env(completions, **kwargs):
rewards = kwargs.get("env_reward", [])
return [float(r) for r in rewards] if rewards else [0.0] * len(completions)
def main():
args = parse_args()
# Select environment mode
if args.env_mode == "local":
env_url = f"http://{args.env_host}:{args.env_port}"
server_process = start_env_server(args.env_host, args.env_port)
elif args.env_mode == "docker":
env_url = f"http://{args.env_host}:{args.env_port}"
server_process = None
print(f"🌍 Using existing Docker environment at {env_url}")
elif args.env_mode == "space":
env_url = args.env_host
server_process = None
print(f"🚀 Using Hugging Face Space environment at {env_url}")
"""Reward function that uses the environment reward from the catch game."""
# Extract environment rewards from kwargs (propagated via extra_fields)
env_rewards = kwargs.get("env_reward", [])
if env_rewards:
return [float(reward) for reward in env_rewards]
else:
raise ValueError(f"Unknown env mode: {args.env_mode}")
gen_url = args.gen_url
client = OpenSpielEnv(base_url=env_url)
dataset = Dataset.from_dict({"prompt": [BASE_PROMPT] * args.dataset_size})
training_args = GRPOConfig(
output_dir=f"{args.model.split('/')[-1]}-GRPO-Catch",
vllm_mode="server",
use_vllm=True,
logging_steps=1,
report_to="trackio",
num_train_epochs=1,
max_completion_length=4,
gradient_accumulation_steps=4,
)
trainer = GRPOTrainer(
model=args.model,
reward_funcs=reward_from_env,
args=training_args,
train_dataset=dataset,
rollout_func=lambda p, a, pc: rollout_func(p, a, pc, client, gen_url),
callbacks=[RichProgressCallback()],
)
trainer.train()
time.sleep(5)
if server_process:
print("🛑 Terminating environment server...")
server_process.terminate()
# Fallback if env_reward is not available
return [0.0] * len(completions)
if __name__ == "__main__":
main()
training_args = GRPOConfig(
output_dir="Qwen2.5-0.5B-GRPO-Catch",
vllm_mode="server",
use_vllm=True,
logging_steps=1,
report_to="trackio",
num_train_epochs=1,
max_completion_length=4,
gradient_accumulation_steps=4,
)
trainer = GRPOTrainer(
model="Qwen/Qwen2.5-0.5B-Instruct",
reward_funcs=reward_from_env,
args=training_args,
train_dataset=dataset,
rollout_func=rollout_func,
callbacks=[RichProgressCallback()],
)
trainer.train()
# Give time for background threads to finish
time.sleep(5)
print("🛑 Terminating environment server...")
server_process.terminate()

View File

@ -12,38 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Simple script to run GRPO training with OpenEnv's Echo environment and a vLLM server. The reward function encourages
longer completions.
Setup:
```sh
uv pip install git+https://github.com/meta-pytorch/OpenEnv.git
```
Usage (2 GPUs required):
# Start the docker container for the Echo environment (recommended). Alternatively, you can run it locally or directly from a HF Space.
```sh
docker run -d -p 8001:8001 registry.hf.space/openenv-echo-env:latest
```
# Spin up server
```sh
CUDA_VISIBLE_DEVICES=0 trl vllm-serve --model Qwen/Qwen2.5-0.5B-Instruct --host 0.0.0.0 --port 8000
```
# Run training
```sh
CUDA_VISIBLE_DEVICES=1 python examples/scripts/openenv/echo.py
```
"""
# ruff: noqa: T201
import argparse
import os
import subprocess
import sys
@ -58,73 +27,80 @@ from envs.echo_env.models import EchoAction
from trl import GRPOConfig, GRPOTrainer, RichProgressCallback
def parse_args():
parser = argparse.ArgumentParser(description="Run GRPO training with Echo environment and vLLM.")
"""
Simple script to run GRPO training with OpenEnv's Echo environment and a vLLM server. The reward function encourages
longer completions.
parser.add_argument("--env-host", type=str, default="0.0.0.0", help="Host for the Echo environment.")
parser.add_argument("--env-port", type=int, default=8001, help="Port for the Echo environment.")
parser.add_argument(
"--env-mode",
choices=["local", "docker", "space"],
default="docker",
help="Where to run the Echo environment: 'local' to launch it, 'docker' if already running, or 'space' to use a remote Space URL.",
)
parser.add_argument(
"--gen-url",
type=str,
default="http://0.0.0.0:8000/generate/",
help="Base URL for the vLLM generation endpoint.",
)
parser.add_argument(
"--model",
type=str,
default="Qwen/Qwen2.5-0.5B-Instruct",
help="Model to use for training.",
)
parser.add_argument(
"--dataset",
type=str,
default="trl-lib/ultrafeedback-prompt",
help="Dataset to use for training.",
)
Setup:
return parser.parse_args()
```sh
uv pip install git+https://github.com/meta-pytorch/OpenEnv.git
```
Usage (2 GPUs required):
# Spin up server
```sh
CUDA_VISIBLE_DEVICES=0 trl vllm-serve --model Qwen/Qwen2.5-0.5B-Instruct --host 0.0.0.0 --port 8000
```
# Run training
```sh
CUDA_VISIBLE_DEVICES=1 python examples/scripts/openenv/echo.py
```
"""
GEN_URL = "http://0.0.0.0:8000/generate/"
ENV_URL = "http://0.0.0.0:8001"
print("⚡ Starting FastAPI server for Echo Environment...")
# Workaround if you can't run the env with Docker
work_dir = str(Path.cwd().parent.absolute())
server_process = subprocess.Popen(
[sys.executable, "-m", "uvicorn", "envs.echo_env.server.app:app", "--host", "0.0.0.0", "--port", "8001"],
env={**os.environ, "PYTHONPATH": f"{work_dir}/src"},
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
text=True,
cwd=work_dir,
)
print("⏳ Waiting for server to start...")
time.sleep(5)
try:
response = requests.get(f"{ENV_URL}/health", timeout=2)
print("\n✅ Echo Environment server is running!")
except Exception as e:
print(f"\n❌ Server failed to start: {e}")
print("\n📋 Checking error output...")
server_process.poll()
if server_process.stderr:
stderr = server_process.stderr.read()
if stderr:
print(stderr)
raise
def start_env_server(env_host: str, env_port: int):
"""Launch the Echo environment server locally."""
env_url = f"http://{env_host}:{env_port}"
print(f"⚡ Starting FastAPI server for Echo Environment on {env_url}...")
work_dir = str(Path.cwd().parent.absolute())
process = subprocess.Popen(
[sys.executable, "-m", "uvicorn", "envs.echo_env.server.app:app", "--host", env_host, "--port", str(env_port)],
env={**os.environ, "PYTHONPATH": f"{work_dir}/src"},
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
text=True,
cwd=work_dir,
)
print("⏳ Waiting for server to start...")
time.sleep(5)
try:
requests.get(f"{env_url}/health", timeout=2)
print("\n✅ Echo Environment server is running!")
except Exception as e:
print(f"\n❌ Server failed to start: {e}")
if process.stderr:
print(process.stderr.read())
raise
return process
# Create HTTP client for Echo Environment
client = EchoEnv(base_url=f"{ENV_URL}")
def rollout_func(
prompts: list[str], args: GRPOConfig, processing_class, client: EchoEnv, gen_url: str
) -> dict[str, list]:
"""Generate completions via vLLM and compute environment rewards."""
def rollout_func(prompts: list[str], args: GRPOConfig, processing_class) -> dict[str, list]:
"""
Custom rollout function that generates completions via vLLM server and computes environment rewards.
Args:
prompts: List of prompts to generate from
args: GRPOConfig containing all sampling parameters
processing_class: Tokenizer/processor for decoding completions
Returns:
Dict containing prompt_ids, completion_ids, logprobs, and env_reward
"""
# 1. Generate completions via vLLM inference server (running on port 8000)
payload = {
"prompts": prompts,
"n": args.num_generations,
@ -135,80 +111,64 @@ def rollout_func(
"max_tokens": args.max_completion_length,
"repetition_penalty": args.repetition_penalty,
}
response = requests.post(GEN_URL, json=payload)
response = requests.post(gen_url, json=payload)
if response.status_code != 200:
print(f"Error response: {response.text}")
response.raise_for_status()
response.raise_for_status()
result = response.json()
completions_text = processing_class.batch_decode(result["completion_ids"], skip_special_tokens=True)
# 2. Step through the environment to get rewards
env_result = client.reset()
env_rewards = []
for msg in completions_text:
env_result = client.step(EchoAction(message=msg))
env_rewards.append(env_result.reward)
# 3. Add environment rewards as extra field
result["env_reward"] = env_rewards
return result
def reward_from_env(completions, **kwargs):
"""Extract environment rewards for training."""
"""Reward function that uses the environment reward."""
# Extract environment rewards from kwargs (propagated via extra_fields)
env_rewards = kwargs.get("env_reward", [])
return [float(r) for r in env_rewards] if env_rewards else [0.0] * len(completions)
def main():
args = parse_args()
# Select environment mode
if args.env_mode == "local":
env_url = f"http://{args.env_host}:{args.env_port}"
server_process = start_env_server(args.env_host, args.env_port)
elif args.env_mode == "docker":
env_url = f"http://{args.env_host}:{args.env_port}"
server_process = None
print(f"🌍 Using existing Echo Environment (Docker) at: {env_url}")
elif args.env_mode == "space":
env_url = args.env_host
server_process = None
print(f"🚀 Using Hugging Face Space environment at: {env_url}")
if env_rewards:
return [float(reward) for reward in env_rewards]
else:
raise ValueError(f"Unknown environment mode: {args.env_mode}")
gen_url = args.gen_url
client = EchoEnv(base_url=env_url)
dataset = load_dataset(args.dataset, split="train[:1000]")
training_args = GRPOConfig(
output_dir=f"{args.model.split('/')[-1]}-GRPO-Rollout",
vllm_mode="server",
use_vllm=True,
logging_steps=1,
report_to="trackio",
num_train_epochs=1,
max_completion_length=2048,
gradient_accumulation_steps=4,
)
trainer = GRPOTrainer(
model=args.model,
reward_funcs=reward_from_env,
args=training_args,
train_dataset=dataset,
rollout_func=lambda p, a, pc: rollout_func(p, a, pc, client, gen_url),
callbacks=[RichProgressCallback()],
)
trainer.train()
time.sleep(5)
if server_process:
print("🛑 Terminating Echo Environment server...")
server_process.terminate()
# Fallback if env_reward is not available
return [0.0] * len(completions)
if __name__ == "__main__":
main()
dataset = load_dataset("trl-lib/ultrafeedback-prompt", split="train[:1000]")
training_args = GRPOConfig(
output_dir="Qwen2.5-0.5B-GRPO-Rollout",
vllm_mode="server",
use_vllm=True,
logging_steps=1,
report_to="trackio",
num_train_epochs=1,
max_completion_length=2048,
gradient_accumulation_steps=4,
)
trainer = GRPOTrainer(
model="Qwen/Qwen2.5-0.5B-Instruct",
reward_funcs=reward_from_env,
args=training_args,
train_dataset=dataset,
rollout_func=rollout_func,
callbacks=[RichProgressCallback()],
)
trainer.train()
# Give time for background threads to finish
time.sleep(5)
print("🛑 Terminating Echo Environment server...")
server_process.terminate()

View File

@ -13,33 +13,18 @@
# limitations under the License.
"""
Simple script to run GRPO training with OpenEnv's Wordle environment and a vLLM server.
GRPO training for Wordle using TRL's `GRPOTrainer` and the TextArena OpenEnv environment.
Setup:
Usage:
# First, start the TextArena Wordle server (Docker or local):
TEXTARENA_ENV_ID=Wordle-v0 TEXTARENA_NUM_PLAYERS=1 \
python -m src.envs.textarena_env.server.app
```sh
uv pip install git+https://github.com/meta-pytorch/OpenEnv.git
```
# Start the vLLM server with your model
CUDA_VISIBLE_DEVICES=0 trl vllm-serve --model Qwen/Qwen2.5-0.5B-Instruct --host 0.0.0.0 --port 8000
Usage (2 GPUs required):
# Start the docker container for the Wordle environment (recommended). Alternatively, you can run it locally or directly from a HF Space.
```sh
docker run -d -p 8001:8001 registry.hf.space/burtenshaw-textarena:latest
# or TEXTARENA_ENV_ID=Wordle-v0 TEXTARENA_NUM_PLAYERS=1 python -m src.envs.textarena_env.server.app
```
# Spin up vLLM server
```sh
CUDA_VISIBLE_DEVICES=0 trl vllm-serve --model Qwen/Qwen3-1.7B --host 0.0.0.0 --port 8000
```
# Run training
```sh
CUDA_VISIBLE_DEVICES=1 python examples/scripts/openenv/wordle.py
```
# Then run this training script:
CUDA_VISIBLE_DEVICES=1 python examples/scripts/openenv/wordle.py
"""
from __future__ import annotations
@ -81,12 +66,12 @@ def parse_args() -> argparse.Namespace:
)
parser.add_argument(
"--model-id",
default="Qwen/Qwen3-1.7B",
default="willcb/Qwen3-1.7B-Wordle",
help="Model identifier passed to GRPOTrainer for fine-tuning.",
)
parser.add_argument(
"--env-url",
default="https://0.0.0.0:8001", # default="https://burtenshaw-textarena.hf.space"
"--textarena-url",
default="https://burtenshaw-textarena.hf.space",
help="Base URL for the TextArena Wordle environment.",
)
parser.add_argument(
@ -520,7 +505,7 @@ def main() -> None:
tokenizer = AutoTokenizer.from_pretrained(cli_args.tokenizer_id)
tokenizer.pad_token = tokenizer.eos_token
env = TextArenaEnv(base_url=cli_args.env_url)
env = TextArenaEnv(base_url=cli_args.textarena_url)
system_prompt = resolve_system_prompt(cli_args.system_prompt_path)

View File

@ -99,26 +99,6 @@ class TestVLLMClientServer(TrlTestCase):
for seq in completion_ids:
assert all(isinstance(tok, int) for tok in seq)
def test_chat(self):
messages = [[{"role": "user", "content": "Hello, AI!"}], [{"role": "user", "content": "Tell me a joke"}]]
outputs = self.client.chat(messages)
prompt_ids = outputs["prompt_ids"]
completion_ids = outputs["completion_ids"]
# Check that the outputs are lists
assert isinstance(prompt_ids, list)
assert isinstance(completion_ids, list)
# Check that the number of sequences are equal to the number of messages
assert len(prompt_ids) == len(messages)
assert len(completion_ids) == len(messages)
# Check that the sequences are lists of integers
for seq in prompt_ids:
assert all(isinstance(tok, int) for tok in seq)
for seq in completion_ids:
assert all(isinstance(tok, int) for tok in seq)
def test_generate_with_params(self):
prompts = ["Hello, AI!", "Tell me a joke"]
completion_ids = self.client.generate(prompts, n=2, repetition_penalty=0.9, temperature=0.8, max_tokens=32)[
@ -200,26 +180,6 @@ class TestVLLMClientServerBaseURL(TrlTestCase):
for seq in completion_ids:
assert all(isinstance(tok, int) for tok in seq)
def test_chat(self):
messages = [[{"role": "user", "content": "Hello, AI!"}], [{"role": "user", "content": "Tell me a joke"}]]
outputs = self.client.chat(messages)
prompt_ids = outputs["prompt_ids"]
completion_ids = outputs["completion_ids"]
# Check that the outputs are lists
assert isinstance(prompt_ids, list)
assert isinstance(completion_ids, list)
# Check that the number of sequences are equal to the number of messages
assert len(prompt_ids) == len(messages)
assert len(completion_ids) == len(messages)
# Check that the sequences are lists of integers
for seq in prompt_ids:
assert all(isinstance(tok, int) for tok in seq)
for seq in completion_ids:
assert all(isinstance(tok, int) for tok in seq)
def test_generate_with_params(self):
prompts = ["Hello, AI!", "Tell me a joke"]
completion_ids = self.client.generate(prompts, n=2, repetition_penalty=0.9, temperature=0.8, max_tokens=32)[
@ -303,26 +263,6 @@ class TestVLLMClientServerTP(TrlTestCase):
for seq in completion_ids:
assert all(isinstance(tok, int) for tok in seq)
def test_chat(self):
messages = [[{"role": "user", "content": "Hello, AI!"}], [{"role": "user", "content": "Tell me a joke"}]]
outputs = self.client.chat(messages)
prompt_ids = outputs["prompt_ids"]
completion_ids = outputs["completion_ids"]
# Check that the outputs are lists
assert isinstance(prompt_ids, list)
assert isinstance(completion_ids, list)
# Check that the number of sequences are equal to the number of messages
assert len(prompt_ids) == len(messages)
assert len(completion_ids) == len(messages)
# Check that the sequences are lists of integers
for seq in prompt_ids:
assert all(isinstance(tok, int) for tok in seq)
for seq in completion_ids:
assert all(isinstance(tok, int) for tok in seq)
def test_update_model_params(self):
model = AutoModelForCausalLM.from_pretrained(self.model_id, device_map=torch_device)
self.client.update_model_params(model)
@ -386,26 +326,6 @@ class TestVLLMClientServerDP(TrlTestCase):
for seq in completion_ids:
assert all(isinstance(tok, int) for tok in seq)
def test_chat(self):
messages = [[{"role": "user", "content": "Hello, AI!"}], [{"role": "user", "content": "Tell me a joke"}]]
outputs = self.client.chat(messages)
prompt_ids = outputs["prompt_ids"]
completion_ids = outputs["completion_ids"]
# Check that the outputs are lists
assert isinstance(prompt_ids, list)
assert isinstance(completion_ids, list)
# Check that the number of sequences are equal to the number of messages
assert len(prompt_ids) == len(messages)
assert len(completion_ids) == len(messages)
# Check that the sequences are lists of integers
for seq in prompt_ids:
assert all(isinstance(tok, int) for tok in seq)
for seq in completion_ids:
assert all(isinstance(tok, int) for tok in seq)
def test_update_model_params(self):
model = AutoModelForCausalLM.from_pretrained(self.model_id, device_map=torch_device)
self.client.update_model_params(model)

View File

@ -19,7 +19,7 @@ from typing import Any
import torch
from accelerate.utils import gather_object
from ...data_utils import apply_chat_template, is_conversational, prepare_multimodal_messages
from ...data_utils import apply_chat_template, is_conversational
from ...trainer.grpo_trainer import GRPOTrainer as _GRPOTrainer
from ...trainer.utils import nanmax, nanmin, nanstd, pad
@ -81,17 +81,8 @@ class GFPOTrainer(_GRPOTrainer):
if images is not None and all(img_list == [] for img_list in images):
images = None
# If the prompts are conversational and the inputs contain images, we need to convert the prompts from
# [{"role": "user", "content": "What color is the sky?"}] to
# [{"role": "user", "content": [{"type": "image", "image": <Image>}, {"type": "text", "text": "What color is the sky?"}]}]
if images is not None:
prompts = [
prepare_multimodal_messages(prompt, image_list)
for prompt, image_list in zip(prompts, images, strict=True)
]
prompt_ids_list, completion_ids_list, num_items_in_batch, sampling_per_token_logps_list, extra_fields = (
self._generate(prompts)
prompt_ids_list, completion_ids_list, num_items_in_batch, sampling_per_token_logps_list = self._generate(
prompts, images
)
# Convert lists of token IDs to padded tensors
@ -290,15 +281,6 @@ class GFPOTrainer(_GRPOTrainer):
agg_completion_lengths = self.accelerator.gather(completion_lengths)
num_items_in_batch = agg_completion_lengths.sum()
if sampling_per_token_logps is not None:
sampling_per_token_logps = sampling_per_token_logps[local_input_indices_to_keep].contiguous()
if old_per_token_logps is not None:
old_per_token_logps = old_per_token_logps[local_input_indices_to_keep].contiguous()
if ref_per_token_logps is not None:
ref_per_token_logps = ref_per_token_logps[local_input_indices_to_keep].contiguous()
if self.use_vllm and self.vllm_importance_sampling_correction:
importance_sampling_ratio = importance_sampling_ratio[local_input_indices_to_keep].contiguous()
# Calculate mean reward per function, but only for samples where the function was applied (non-NaN values)
for i, reward_func_name in enumerate(self.reward_func_names):
mean_rewards = torch.nanmean(rewards_per_func[:, i]).item()
@ -327,10 +309,15 @@ class GFPOTrainer(_GRPOTrainer):
self._logs["advantages"].extend(all_process_advantages.tolist())
if images is not None:
self._logs["images"].extend(all_images)
self._logs["images"].extend(gather_object(images))
if self.use_vllm and self.vllm_importance_sampling_correction:
delta = torch.abs(old_per_token_logps - sampling_per_token_logps)
if self.num_remains_in_group is not None and mode == "train":
delta = delta[local_input_indices_to_keep]
importance_sampling_ratio = importance_sampling_ratio[local_input_indices_to_keep]
delta = delta[completion_mask.bool()]
mean_delta = torch.mean(delta) if delta.numel() > 0 else torch.tensor(0.0, device=device)
max_delta = torch.max(delta) if delta.numel() > 0 else torch.tensor(0.0, device=device)

View File

@ -14,7 +14,6 @@
import atexit
import base64
import copy
import logging
import socket
import time
@ -45,13 +44,6 @@ if is_vllm_available():
logger = logging.getLogger(__name__)
def pil_to_base64(image):
buffer = BytesIO()
image.save(buffer, format="PNG")
img_bytes = buffer.getvalue()
return base64.b64encode(img_bytes).decode("utf-8")
class VLLMClient:
"""
A client class to interact with a vLLM server.
@ -192,7 +184,7 @@ class VLLMClient:
truncate_prompt_tokens: int | None = None,
guided_decoding_regex: str | None = None,
generation_kwargs: dict | None = None,
) -> dict[str, list[list[int]]]:
) -> list[list[int]]:
"""
Generates model completions for the provided prompts.
@ -237,6 +229,12 @@ class VLLMClient:
"""
url = f"{self.base_url}/generate/"
def pil_to_base64(image):
buffer = BytesIO()
image.save(buffer, format="PNG")
img_bytes = buffer.getvalue()
return base64.b64encode(img_bytes).decode("utf-8")
# Convert PIL images to base64 strings
images = [pil_to_base64(img) for img in images] if images else None
@ -267,102 +265,6 @@ class VLLMClient:
else:
raise Exception(f"Request failed: {response.status_code}, {response.text}")
def chat(
self,
messages: list[list[dict]],
n: int = 1,
repetition_penalty: float = 1.0,
temperature: float = 1.0,
top_p: float = 1.0,
top_k: int = -1,
min_p: float = 0.0,
max_tokens: int = 16,
truncate_prompt_tokens: int | None = None,
guided_decoding_regex: str | None = None,
generation_kwargs: dict | None = None,
chat_template_kwargs: dict | None = None,
) -> dict[str, list[list[int]]]:
"""
Generates model completions for the provided chat messages.
Args:
messages (`list[list[dict]]`):
List of message lists for which the model will generate completions. Each message is a dictionary with
keys like "role" and "content".
n (`int`, *optional*, defaults to `1`):
Number of completions to generate for each message list.
repetition_penalty (`float`, *optional*, defaults to `1.0`):
Parameter for repetition penalty. 1.0 means no penalty.
temperature (`float`, *optional*, defaults to `1.0`):
Temperature parameter for sampling. Higher values increase diversity.
top_p (`float`, *optional*, defaults to `1.0`):
Top-p sampling parameter.`1.0` means no truncation.
top_k (`int`, *optional*, defaults to `-1`):
Top-k sampling parameter. `-1` means no truncation.
min_p (`float`, *optional*, defaults to `0.0`):
Minimum probability for sampling.
max_tokens (`int`, *optional*, defaults to `16`):
Maximum number of tokens to generate for each message list.
truncate_prompt_tokens (`int`, *optional*):
If set to `-1`, will use the truncation size supported by the model. If set to an integer k, will use
only the last k tokens from the prompt (i.e., left truncation). If set to `None`, truncation is
disabled.
guided_decoding_regex (`str`, *optional*):
Regular expression to guide the decoding process.
generation_kwargs (`dict`, *optional*):
Additional generation parameters to pass to the vLLM `SamplingParams`. This can include parameters like
`seed`, `frequency_penalty`, etc. If it contains keys that conflict with the other parameters, they
will override them.
chat_template_kwargs (`dict`, *optional*):
Additional keyword arguments to customize the chat template used by the model.
Returns:
`dict` with keys:
- `prompt_ids` (`list[list[int]]`):
List of lists of token IDs representing the tokenized input messages.
- `completion_ids` (`list[list[int]]`):
List of lists of token IDs representing the model-generated completions for each message list.
- `logprobs` (`list[list[float]]`):
List of lists of log probabilities for each generated token.
"""
url = f"{self.base_url}/chat/"
# Convert PIL images to base64 strings
messages = copy.deepcopy(messages) # avoid modifying the original messages
for message_list in messages:
for message in message_list:
if isinstance(message["content"], list):
for part in message["content"]:
if part["type"] == "image_pil":
part["image_pil"] = pil_to_base64(part["image_pil"])
response = self.session.post(
url,
json={
"messages": messages,
"n": n,
"repetition_penalty": repetition_penalty,
"temperature": temperature,
"top_p": top_p,
"top_k": top_k,
"min_p": min_p,
"max_tokens": max_tokens,
"truncate_prompt_tokens": truncate_prompt_tokens,
"guided_decoding_regex": guided_decoding_regex,
"generation_kwargs": generation_kwargs or {},
"chat_template_kwargs": chat_template_kwargs or {},
},
)
if response.status_code == 200:
json_response = response.json()
return {
"prompt_ids": json_response["prompt_ids"],
"completion_ids": json_response["completion_ids"],
"logprobs": json_response["logprobs"],
}
else:
raise Exception(f"Request failed: {response.status_code}, {response.text}")
def init_communicator(self, device: torch.device | str | int = 0):
"""
Initializes the weight update group in a distributed setup for model synchronization.

View File

@ -580,7 +580,7 @@ def main(script_args: ScriptArguments):
"max_tokens": request.max_tokens,
"truncate_prompt_tokens": request.truncate_prompt_tokens,
"guided_decoding": guided_decoding,
"logprobs": 0, # enable returning log probabilities; 0 means for the sampled tokens only
"logprobs": 0,
}
generation_kwargs.update(request.generation_kwargs)
sampling_params = SamplingParams(**generation_kwargs)
@ -615,143 +615,6 @@ def main(script_args: ScriptArguments):
]
return {"prompt_ids": prompt_ids, "completion_ids": completion_ids, "logprobs": logprobs}
class ChatRequest(BaseModel):
messages: list[list[dict]]
n: int = 1
repetition_penalty: float = 1.0
temperature: float = 1.0
top_p: float = 1.0
top_k: int = -1
min_p: float = 0.0
max_tokens: int = 16
truncate_prompt_tokens: int | None = None
guided_decoding_regex: str | None = None
generation_kwargs: dict = field(default_factory=dict)
chat_template_kwargs: dict = field(default_factory=dict)
class ChatResponse(BaseModel):
prompt_ids: list[list[int]]
completion_ids: list[list[int]]
logprobs: list[list[float]]
@app.post("/chat/", response_model=ChatResponse)
async def chat(request: ChatRequest):
"""
Generates completions for the provided chat messages.
Args:
request (`ChatRequest`):
- `messages` (list of `dict`): A list of messages (dicts with "role" and "content" keys) for the model
to generate completions.
- `n` (`int`, *optional*, defaults to `1`): Number of completions to generate for each prompt.
- `repetition_penalty` (`float`, *optional*, defaults to `1.0`): Repetition penalty to apply during
generation.
- `temperature` (`float`, *optional*, defaults to `1.0`): Temperature for sampling. Higher values lead
to more random outputs.
- `top_p` (`float`, *optional*, defaults to `1.0`): Top-p (nucleus) sampling parameter. It controls the
diversity of the generated text.
- `top_k` (`int`, *optional*, defaults to `-1`): Top-k sampling parameter. If set to `-1`, it disables
top-k sampling.
- `min_p` (`float`, *optional*, defaults to `0.0`): Minimum probability threshold for sampling.
- `max_tokens` (`int`, *optional*, defaults to `16`): Maximum number of tokens to generate for each
completion.
- `truncate_prompt_tokens` (`int`, *optional*): If set to `-1`, will use the truncation size supported
by the model. If set to an integer k, will use only the last k tokens from the prompt (i.e., left
truncation). If set to `None`, truncation is disabled.
- `guided_decoding_regex` (`str`, *optional*): A regex pattern for guided decoding. If provided, the
model will only generate tokens that match this regex pattern.
- `generation_kwargs` (`dict`, *optional*): Additional generation parameters to pass to the vLLM
`SamplingParams`. This can include parameters like `seed`, `frequency_penalty`, etc. If it contains
keys that conflict with the other parameters, they will override them.
- `chat_template_kwargs` (`dict`, *optional*): Additional keyword arguments to pass to the chat
template.
Returns:
`ChatResponse`:
- `prompt_ids` (list of list of `int`): A list of lists of token IDs for each input prompt.
- `completion_ids` (list of list of `int`): A list of lists of token IDs for each generated completion.
- `logprobs` (list of list of `float`): A list of lists of log probabilities for each token in the
generated completions.
Example request:
```bash
curl -X POST 'http://0.0.0.0:8000/chat/' \
-H 'Content-Type: application/json' \
-d '{"messages": [[{ "role": "user", "content": "Hello!" }]]}'
```
Example response:
```json
{
"prompt_ids": [[151644, 872, 198, 9707, 0, 151645, 198, 151644, 77091, 198]],
"completion_ids":[[151667, 198, 32313, 11, 279, 1196, 1101, 1053, 330, 9707, 8958, 773, 358, 1184, 311, 5889]],
"logprobs": [[-0.00029404606902971864, -3.576278118089249e-07, -0.09024181962013245, -6.389413465512916e-05, -0.038671817630529404, -0.00013314791431184858, -0.5868351459503174, -0.09682723134756088, -0.06609706580638885, -0.00023803261865396053, -0.02242819033563137, -0.8185162544250488, -0.04954879730939865, -0.3169460594654083, -4.887569048150908e-06, -0.006023705471307039]]
}
```
"""
# Convert PIL images to base64 strings
for message_list in request.messages:
for message in message_list:
if isinstance(message["content"], list):
for part in message["content"]:
if part["type"] == "image_pil":
part["image_pil"] = Image.open(BytesIO(base64.b64decode(part["image_pil"])))
# Guided decoding, if enabled
if request.guided_decoding_regex is not None:
guided_decoding = GuidedDecodingParams(regex=request.guided_decoding_regex)
else:
guided_decoding = None
generation_kwargs = {
"n": request.n,
"repetition_penalty": request.repetition_penalty,
"temperature": request.temperature,
"top_p": request.top_p,
"top_k": request.top_k,
"min_p": request.min_p,
"max_tokens": request.max_tokens,
"truncate_prompt_tokens": request.truncate_prompt_tokens,
"guided_decoding": guided_decoding,
"logprobs": 0, # enable returning log probabilities; 0 means for the sampled tokens only
}
generation_kwargs.update(request.generation_kwargs)
sampling_params = SamplingParams(**generation_kwargs)
# Evenly distribute prompts across DP ranks
chunked_messages = chunk_list(request.messages, script_args.data_parallel_size)
# Send the messages to each worker
for connection, messages in zip(connections, chunked_messages, strict=True):
# When the number of messages is less than data_parallel_size, some workers will receive empty messages.
# However, vLLM requires that we always send at least one prompt. So we send a placeholder prompt to comply
# with vLLM's requirement, and we later ignore the result.
if not messages:
messages = [[{"role": "user", "content": "<placeholder>"}]]
kwargs = {
"messages": messages,
"sampling_params": sampling_params,
"chat_template_kwargs": request.chat_template_kwargs,
}
connection.send({"type": "call", "method": "chat", "kwargs": kwargs})
# Receive results
all_outputs = [connection.recv() for connection in connections]
# Handle empty prompts (see above)
all_outputs = [output for output, prompts in zip(all_outputs, chunked_messages, strict=True) if prompts]
# Flatten and combine all results
all_outputs = list(chain.from_iterable(all_outputs)) # from list of list to single list
prompt_ids = [output.prompt_token_ids for output in all_outputs]
completion_ids = [list(output.token_ids) for outputs in all_outputs for output in outputs.outputs]
logprobs: list[list[float]] = [
[sanitize_logprob(next(iter(logprob.values()))) for logprob in output.logprobs]
for outputs in all_outputs
for output in outputs.outputs
]
return {"prompt_ids": prompt_ids, "completion_ids": completion_ids, "logprobs": logprobs}
class InitCommunicatorRequest(BaseModel):
host: str
port: int

View File

@ -24,7 +24,6 @@ from pathlib import Path
from typing import Any
import datasets
import pandas as pd
import torch
import torch.utils.data
import transformers
@ -1191,8 +1190,9 @@ class GRPOTrainer(BaseTrainer):
)
else:
if is_conversational({"prompt": ordered_set_of_prompts[0]}):
# FIXME: this endpoint doesn't exist in vllm_client
output = self.vllm_client.chat(
messages=ordered_set_of_prompts,
prompts=ordered_set_of_prompts,
**sampling_params,
chat_template_kwargs=self.chat_template_kwargs,
)
@ -1246,7 +1246,7 @@ class GRPOTrainer(BaseTrainer):
"max_tokens": self.max_completion_length,
"truncate_prompt_tokens": self.max_prompt_length,
"guided_decoding": guided_decoding,
"logprobs": 0, # enable returning log probabilities; 0 means for the sampled tokens only
"logprobs": 0, # only return the logprob of the generated token
}
if self.args.generation_kwargs is not None:
generation_kwargs.update(self.args.generation_kwargs)
@ -1922,45 +1922,45 @@ class GRPOTrainer(BaseTrainer):
self.num_completions_to_print,
)
table = {
"step": [str(self.state.global_step)] * len(self._logs["prompt"]),
"prompt": self._logs["prompt"],
"completion": self._logs["completion"],
**self._logs["rewards"],
"advantage": self._logs["advantages"],
}
logging_backends = []
if self.args.report_to and "wandb" in self.args.report_to and wandb.run is not None:
logging_backends.append(wandb)
if self.args.report_to and "trackio" in self.args.report_to:
logging_backends.append(trackio)
df_base = pd.DataFrame(table)
images_raw = self._logs["images"] or []
if logging_backends:
import pandas as pd
for logging_backend in self.args.report_to:
if logging_backend == "wandb":
table = {
"step": [str(self.state.global_step)] * len(self._logs["prompt"]),
"prompt": self._logs["prompt"],
"completion": self._logs["completion"],
**self._logs["rewards"],
"advantage": self._logs["advantages"],
}
df_base = pd.DataFrame(table)
images_raw = self._logs["images"] or []
for logging_backend in logging_backends:
if images_raw:
images = []
for image_list in self._logs["images"]:
images.append([wandb.Image(image) for image in image_list])
df = pd.concat([df_base, pd.Series(images, name="image")], axis=1, copy=False)
# Convert images per backend and derive a dataframe that shares base columns
if logging_backend is wandb:
images = []
for image_list in self._logs["images"]:
images.append([wandb.Image(image) for image in image_list])
df = pd.concat([df_base, pd.Series(images, name="image")], axis=1, copy=False)
elif logging_backend is trackio:
# TODO: Implement once supported upstream https://github.com/gradio-app/trackio/issues/327
logger.info("Skipping image logging for Trackio")
df = df_base
else:
df = df_base
if self.wandb_log_unique_prompts:
df = df.drop_duplicates(subset=["prompt"])
wandb.log({"completions": wandb.Table(dataframe=df)})
if logging_backend == "trackio":
if images_raw:
# TODO: Implement once supported upstream https://github.com/gradio-app/trackio/issues/334
logger.info("Skipping image logging for Trackio")
df = df_base
# images = []
# for image_list in self._logs["images"]:
# images.append([trackio.Image(image) for image in image_list])
# df = pd.concat([df_base, pd.Series(images, name="image")], axis=1, copy=False)
else:
df = df_base
trackio.log({"completions": trackio.Table(dataframe=df)})
logging_backend.log({"completions": logging_backend.Table(dataframe=df)})
# Ensure the model card is saved along with the checkpoint
def _save_checkpoint(self, model, trial):

View File

@ -24,7 +24,6 @@ from pathlib import Path
from typing import Any
import datasets
import pandas as pd
import torch
import torch.utils.data
import transformers
@ -1011,7 +1010,7 @@ class RLOOTrainer(BaseTrainer):
with profiling_context(self, "vLLM.generate"):
if is_conversational({"prompt": ordered_set_of_prompts[0]}):
output = self.vllm_client.chat(
messages=ordered_set_of_prompts,
prompts=ordered_set_of_prompts,
**sampling_params,
chat_template_kwargs=self.chat_template_kwargs,
)
@ -1531,45 +1530,45 @@ class RLOOTrainer(BaseTrainer):
self.num_completions_to_print,
)
table = {
"step": [str(self.state.global_step)] * len(self._logs["prompt"]),
"prompt": self._logs["prompt"],
"completion": self._logs["completion"],
**self._logs["rewards"],
"advantage": self._logs["advantages"],
}
logging_backends = []
if self.args.report_to and "wandb" in self.args.report_to and wandb.run is not None:
logging_backends.append(wandb)
if self.args.report_to and "trackio" in self.args.report_to:
logging_backends.append(trackio)
df_base = pd.DataFrame(table)
images_raw = self._logs["images"] or []
if logging_backends:
import pandas as pd
for logging_backend in self.args.report_to:
if logging_backend == "wandb":
table = {
"step": [str(self.state.global_step)] * len(self._logs["prompt"]),
"prompt": self._logs["prompt"],
"completion": self._logs["completion"],
**self._logs["rewards"],
"advantage": self._logs["advantages"],
}
df_base = pd.DataFrame(table)
images_raw = self._logs["images"] or []
for logging_backend in logging_backends:
if images_raw:
images = []
for image_list in self._logs["images"]:
images.append([wandb.Image(image) for image in image_list])
df = pd.concat([df_base, pd.Series(images, name="image")], axis=1, copy=False)
# Convert images per backend and derive a dataframe that shares base columns
if logging_backend is wandb:
images = []
for image_list in self._logs["images"]:
images.append([wandb.Image(image) for image in image_list])
df = pd.concat([df_base, pd.Series(images, name="image")], axis=1, copy=False)
elif logging_backend is trackio:
# TODO: Implement once supported upstream https://github.com/gradio-app/trackio/issues/327
logger.info("Skipping image logging for Trackio")
df = df_base
else:
df = df_base
if self.wandb_log_unique_prompts:
df = df.drop_duplicates(subset=["prompt"])
wandb.log({"completions": wandb.Table(dataframe=df)})
if logging_backend == "trackio":
if images_raw:
# TODO: Implement once supported upstream https://github.com/gradio-app/trackio/issues/334
logger.info("Skipping image logging for Trackio")
df = df_base
# images = []
# for image_list in self._logs["images"]:
# images.append([trackio.Image(image) for image in image_list])
# df = pd.concat([df_base, pd.Series(images, name="image")], axis=1, copy=False)
else:
df = df_base
trackio.log({"completions": trackio.Table(dataframe=df)})
logging_backend.log({"completions": logging_backend.Table(dataframe=df)})
# Ensure the model card is saved along with the checkpoint
def _save_checkpoint(self, model, trial):