mirror of
https://github.com/huggingface/trl.git
synced 2025-11-06 14:24:29 +08:00
Compare commits
7 Commits
main
...
docs/unify
| Author | SHA1 | Date | |
|---|---|---|---|
| 9bf8db4887 | |||
| 5dfb2db0c1 | |||
| c34de94903 | |||
| 800a4d928a | |||
| 6f906d5087 | |||
| 91e540ce09 | |||
| 580c6bb951 |
2
.github/workflows/tests_latest.yml
vendored
2
.github/workflows/tests_latest.yml
vendored
@ -25,7 +25,7 @@ jobs:
|
||||
steps:
|
||||
- name: Git checkout
|
||||
uses: actions/checkout@v4
|
||||
with: { ref: v0.25-release }
|
||||
with: { ref: v0.24-release }
|
||||
|
||||
- name: Set up Python 3.12
|
||||
uses: actions/setup-python@v5
|
||||
|
||||
@ -31,4 +31,4 @@ keywords:
|
||||
- pytorch
|
||||
- transformers
|
||||
license: Apache-2.0
|
||||
version: "0.25"
|
||||
version: "0.24"
|
||||
|
||||
@ -52,14 +52,6 @@ Community tutorials are made by active members of the Hugging Face community who
|
||||
| Visual QA | [`DPOTrainer`] | Fine-Tuning a Vision Language Model with TRL using MPO | [Sergio Paniego](https://huggingface.co/sergiopaniego) | [Link](https://huggingface.co/learn/cookbook/fine_tuning_vlm_mpo) | [](https://colab.research.google.com/github/huggingface/cookbook/blob/main/notebooks/en/fine_tuning_vlm_mpo.ipynb) |
|
||||
| Reinforcement Learning | [`GRPOTrainer`] | Post training a VLM for reasoning with GRPO using TRL | [Sergio Paniego](https://huggingface.co/sergiopaniego) | [Link](https://huggingface.co/learn/cookbook/fine_tuning_vlm_grpo_trl) | [](https://colab.research.google.com/github/huggingface/cookbook/blob/main/notebooks/en/fine_tuning_vlm_grpo_trl.ipynb) |
|
||||
|
||||
## Speech Language Models
|
||||
|
||||
### Tutorials
|
||||
|
||||
| Task | Class | Description | Author | Tutorial |
|
||||
| --- | --- | --- | --- | --- |
|
||||
| Text-to-Speech | [`GRPOTrainer`] | Post training a Speech Language Model with GRPO using TRL | [Steven Zheng](https://huggingface.co/Steveeeeeeen) | [Link](https://huggingface.co/blog/Steveeeeeeen/llasa-grpo) |
|
||||
|
||||
## Contributing
|
||||
|
||||
If you have a tutorial that you would like to add to this list, please open a PR to add it. We will review it and merge it if it is relevant to the community.
|
||||
|
||||
@ -11,7 +11,7 @@ In this guide, we’ll focus on **how to integrate OpenEnv with TRL**, but feel
|
||||
To use OpenEnv with TRL, install the framework:
|
||||
|
||||
```bash
|
||||
pip install git+https://github.com/meta-pytorch/OpenEnv.git
|
||||
pip install openenv-core
|
||||
```
|
||||
|
||||
## Using `rollout_func` with OpenEnv environments
|
||||
@ -65,33 +65,6 @@ By using OpenEnv in this loop, you can:
|
||||
* Plug in custom simulators, web APIs, or evaluators as environments.
|
||||
* Pass structured reward signals back into RL training seamlessly.
|
||||
|
||||
## Running the Environments
|
||||
|
||||
You can run OpenEnv environments in three different ways:
|
||||
|
||||
1. **Local Docker container** *(recommended)*
|
||||
|
||||
To start a Docker container:
|
||||
* Open the environment on the Hugging Face Hub.
|
||||
* Click the **⋮ (three dots)** menu.
|
||||
* Select **“Run locally.”**
|
||||
* Copy and execute the provided command in your terminal.
|
||||
|
||||
Example:
|
||||
```bash
|
||||
docker run -d -p 8001:8001 registry.hf.space/openenv-echo-env:latest
|
||||
```
|
||||

|
||||
2. **Local Python process**: Launch the environment directly using Uvicorn.
|
||||
You can start the server manually as a local process. For more details about the available environments, refer to the [OpenEnv repository](https://github.com/meta-pytorch/OpenEnv/tree/main/src/envs).
|
||||
```bash
|
||||
python -m uvicorn envs.echo_env.server.app:app --host 0.0.0.0 --port 8001
|
||||
```
|
||||
3. **Hugging Face Spaces**: Connect to a hosted environment running on the Hugging Face Hub.
|
||||
To find the connection URL, open the Space page, click the **⋮ (three dots)** menu, and select **“Embed this Space.”**
|
||||
You can then use that URL to connect directly from your client.
|
||||
Keep in mind that public Spaces may have rate limits or temporarily go offline if inactive.
|
||||
|
||||
## A simple example
|
||||
|
||||
The [echo.py](https://github.com/huggingface/trl/blob/main/examples/scripts/openenv/echo.py) script demonstrates a minimal, end-to-end integration between TRL and OpenEnv. In this example, the Echo environment rewards completions based on their text length, encouraging the model to generate longer outputs. This pattern can be extended to any custom environment that provides structured feedback or task-based rewards:
|
||||
@ -102,15 +75,6 @@ from trl import GRPOConfig, GRPOTrainer
|
||||
|
||||
# Create HTTP client for Echo Environment
|
||||
client = EchoEnv.from_docker_image("echo-env:latest")
|
||||
"""
|
||||
Alternatively, you can start the environment manually with Docker and connect to it:
|
||||
|
||||
# Step 1: Start the Echo environment
|
||||
docker run -d -p 8001:8001 registry.hf.space/openenv-echo-env:latest
|
||||
|
||||
# Step 2: Connect the client to the running container
|
||||
client = EchoEnv(base_url="http://0.0.0.0:8001")
|
||||
"""
|
||||
|
||||
def rollout_func(prompts, args, processing_class):
|
||||
# 1. Generate completions via vLLM inference server (running on port 8000)
|
||||
@ -187,21 +151,6 @@ CUDA_VISIBLE_DEVICES=0 trl vllm-serve --model Qwen/Qwen2.5-0.5B-Instruct --host
|
||||
CUDA_VISIBLE_DEVICES=1 python examples/scripts/openenv/echo.py
|
||||
```
|
||||
|
||||
Alternatively, you can manually start the Echo environment in a Docker container before running the training:
|
||||
|
||||
```bash
|
||||
# Launch the Echo environment
|
||||
docker run -d -p 8001:8001 registry.hf.space/openenv-echo-env:latest
|
||||
```
|
||||
|
||||
Then, initialize the client using:
|
||||
|
||||
`client = EchoEnv(base_url="http://0.0.0.0:8001")`
|
||||
|
||||
instead of:
|
||||
|
||||
`client = EchoEnv.from_docker_image("echo-env:latest")`.
|
||||
|
||||
Below is the reward curve from training:
|
||||
|
||||
<iframe src="https://trl-lib-trackio.hf.space?project=openenv&metrics=train/rewards/reward_from_env/mean&runs=qgallouedec-1761202871&sidebar=hidden&navbar=hidden" style="width:600px; height:500px; border:0;"></iframe>
|
||||
@ -403,33 +352,22 @@ trainer = GRPOTrainer(
|
||||
trainer.train()
|
||||
```
|
||||
|
||||
### Running the Advanced Example
|
||||
### Running the Example
|
||||
|
||||
The example requires two GPUs:
|
||||
|
||||
```bash
|
||||
# Terminal 1: Start vLLM inference server
|
||||
CUDA_VISIBLE_DEVICES=0 trl vllm-serve --model Qwen/Qwen3-1.7B --host 0.0.0.0 --port 8000
|
||||
CUDA_VISIBLE_DEVICES=0 trl vllm-serve --model Qwen/Qwen2.5-0.5B-Instruct --host 0.0.0.0 --port 8000
|
||||
|
||||
# Terminal 2: Run GRPO training with OpenEnv
|
||||
CUDA_VISIBLE_DEVICES=1 python examples/scripts/openenv/wordle.py
|
||||
```
|
||||
|
||||
Again, you can manually start the TextArena environment in a Docker container before running the training.
|
||||
In this case, initialize the client with
|
||||
`client = TextArenaEnv(base_url="http://0.0.0.0:8001")`
|
||||
instead of
|
||||
`client = TextArenaEnv.from_docker_image("registry.hf.space/burtenshaw-textarena:latest")`:
|
||||
|
||||
```bash
|
||||
# Launch the TextArena environment
|
||||
docker run -d -p 8001:8001 registry.hf.space/burtenshaw-textarena:latest
|
||||
```
|
||||
|
||||
### Results
|
||||
|
||||
The resulting model improves it's performance on the game, both by reducing the number of repetitions and by increasing the number of correct guesses. However, the the Qwen3-1.7B model we trained is not able to consistently win the game. The following reward curve shows the coverage of the model's guesses and the coverage of correct Y and G letters.
|
||||
|
||||
<iframe src="https://burtenshaw-wordle-grpo.hf.space?project=group-Qwen-Qwen3-17B&metrics=reward&runs=run-2025-10-26_09-39-49,run-2025-10-26_08-04-49&sidebar=hidden&navbar=hidden" style="width:1600px; height:500px; border:0;"></iframe>
|
||||
<iframe src="https://burtenshaw-wordle-grpo.hf.space/?project=group-Qwen-Qwen3-17B&metrics=train/rewards/reward_coverage/mean&runs=run-2025-10-26_09-39-49&sidebar=hidden&navbar=hidden" style="width:600px; height:500px; border:0;"></iframe>
|
||||
|
||||
We experimented larger models like `gpt-oss-20b` and found that model was able to consistently win the game. However, this requires a lot of compute to train and the model. Why not try this out yourself?
|
||||
@ -9,7 +9,7 @@ If you have fine-tuned a model fully, meaning without the use of PEFT you can si
|
||||
```python
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM
|
||||
|
||||
model_name_or_path = "kashif/stack-llama-2" #path/to/your/model/or/name/on/hub
|
||||
model_name_or_path = "Qwen/Qwen3-0.6B" #path/to/your/model/or/name/on/hub
|
||||
device = "cpu" # or "cuda" if you have a GPU
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(model_name_or_path).to(device)
|
||||
@ -25,7 +25,7 @@ Alternatively you can also use the pipeline:
|
||||
```python
|
||||
from transformers import pipeline
|
||||
|
||||
model_name_or_path = "kashif/stack-llama-2" #path/to/your/model/or/name/on/hub
|
||||
model_name_or_path = "Qwen/Qwen3-0.6B" #path/to/your/model/or/name/on/hub
|
||||
pipe = pipeline("text-generation", model=model_name_or_path)
|
||||
print(pipe("This movie was really")[0]["generated_text"])
|
||||
```
|
||||
@ -36,7 +36,7 @@ print(pipe("This movie was really")[0]["generated_text"])
|
||||
from peft import PeftConfig, PeftModel
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
base_model_name = "kashif/stack-llama-2" #path/to/your/model/or/name/on/hub
|
||||
base_model_name = "Qwen/Qwen3-0.6B" #path/to/your/model/or/name/on/hub
|
||||
adapter_model_name = "path/to/my/adapter"
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(base_model_name)
|
||||
|
||||
@ -12,38 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Simple script to run GRPO training with OpenEnv's Catch environment (OpenSpiel) and a vLLM server. The reward function
|
||||
is based on the catch game where the agent tries to catch falling balls.
|
||||
|
||||
Setup:
|
||||
|
||||
```sh
|
||||
uv pip install git+https://github.com/meta-pytorch/OpenEnv.git
|
||||
```
|
||||
|
||||
Usage (2 GPUs required):
|
||||
|
||||
# Start the docker container for the Catch environment (recommended). Alternatively, you can run it locally or directly from a HF Space.
|
||||
```sh
|
||||
docker run -d -p 8001:8001 registry.hf.space/openenv-openspiel-env:latest
|
||||
```
|
||||
|
||||
# Spin up vLLM server
|
||||
|
||||
```sh
|
||||
CUDA_VISIBLE_DEVICES=0 trl vllm-serve --model Qwen/Qwen2.5-0.5B-Instruct --host 0.0.0.0 --port 8000
|
||||
```
|
||||
|
||||
# Run training
|
||||
|
||||
```sh
|
||||
CUDA_VISIBLE_DEVICES=1 python examples/scripts/openenv/catch.py
|
||||
```
|
||||
"""
|
||||
|
||||
# ruff: noqa: T201
|
||||
import argparse
|
||||
import os
|
||||
import re
|
||||
import subprocess
|
||||
@ -59,79 +28,34 @@ from envs.openspiel_env.models import OpenSpielAction
|
||||
from trl import GRPOConfig, GRPOTrainer, RichProgressCallback, apply_chat_template
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(description="Run GRPO training with OpenSpiel Catch environment and vLLM.")
|
||||
"""
|
||||
Simple script to run GRPO training with OpenEnv's Catch environment (OpenSpiel) and a vLLM server. The reward function
|
||||
is based on the catch game where the agent tries to catch falling balls.
|
||||
|
||||
# --- Environment settings ---
|
||||
parser.add_argument("--env-host", type=str, default="0.0.0.0", help="Host for the environment server.")
|
||||
parser.add_argument("--env-port", type=int, default=8001, help="Port for the environment server.")
|
||||
parser.add_argument(
|
||||
"--env-mode",
|
||||
choices=["local", "docker", "space"],
|
||||
default="docker",
|
||||
help="Where to run the environment: 'local', 'docker', or 'space'.",
|
||||
)
|
||||
# --- Generation and model config ---
|
||||
parser.add_argument(
|
||||
"--gen-url",
|
||||
type=str,
|
||||
default="http://0.0.0.0:8000/generate/",
|
||||
help="vLLM generation endpoint URL.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model",
|
||||
type=str,
|
||||
default="Qwen/Qwen2.5-0.5B-Instruct",
|
||||
help="Model name or path.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dataset-size",
|
||||
type=int,
|
||||
default=1000,
|
||||
help="Number of prompts to use for training dataset.",
|
||||
)
|
||||
Setup:
|
||||
|
||||
return parser.parse_args()
|
||||
```sh
|
||||
uv pip install git+https://github.com/meta-pytorch/OpenEnv.git
|
||||
uv pip install open_spiel rich trackio
|
||||
```
|
||||
|
||||
Usage (2 GPUs required):
|
||||
|
||||
def start_env_server(env_host: str, env_port: int):
|
||||
"""Launch the OpenSpiel Catch environment locally via uvicorn."""
|
||||
env_url = f"http://{env_host}:{env_port}"
|
||||
print(f"⚡ Starting FastAPI server for OpenSpiel Catch Environment on {env_url}...")
|
||||
# Spin up vLLM server
|
||||
|
||||
work_dir = str(Path.cwd().parent.absolute())
|
||||
process = subprocess.Popen(
|
||||
[
|
||||
sys.executable,
|
||||
"-m",
|
||||
"uvicorn",
|
||||
"envs.openspiel_env.server.app:app",
|
||||
"--host",
|
||||
env_host,
|
||||
"--port",
|
||||
str(env_port),
|
||||
],
|
||||
env={**os.environ, "PYTHONPATH": f"{work_dir}/src"},
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
text=True,
|
||||
cwd=work_dir,
|
||||
)
|
||||
```sh
|
||||
CUDA_VISIBLE_DEVICES=0 trl vllm-serve --model Qwen/Qwen2.5-0.5B-Instruct --host 0.0.0.0 --port 8000
|
||||
```
|
||||
|
||||
print("⏳ Waiting for server to start...")
|
||||
time.sleep(5)
|
||||
# Run training
|
||||
|
||||
try:
|
||||
requests.get(f"{env_url}/health", timeout=2)
|
||||
print("\n✅ OpenSpiel Catch Environment server is running!")
|
||||
except Exception as e:
|
||||
print(f"\n❌ Server failed to start: {e}")
|
||||
if process.stderr:
|
||||
print(process.stderr.read())
|
||||
raise
|
||||
|
||||
return process
|
||||
```sh
|
||||
CUDA_VISIBLE_DEVICES=1 python examples/scripts/openenv/catch.py
|
||||
```
|
||||
"""
|
||||
|
||||
GEN_URL = "http://0.0.0.0:8000/generate/"
|
||||
ENV_URL = "http://0.0.0.0:8001"
|
||||
|
||||
BASE_PROMPT = """You are an AI agent playing the game **Catch**.
|
||||
|
||||
@ -144,64 +68,135 @@ BASE_PROMPT = """You are an AI agent playing the game **Catch**.
|
||||
- You get **–1 reward** if you miss it.
|
||||
|
||||
### Observation Format
|
||||
Each observation is a flattened 10x5 grid (list of 50 floats).
|
||||
- 1.0 → occupied (ball or paddle)
|
||||
- 0.0 → empty cell
|
||||
|
||||
### Actions:
|
||||
- `0` → Move left
|
||||
- `1` → Stay
|
||||
- `2` → Move right
|
||||
- `observation`: a list of **50 numbers (floats)** representing the entire grid, flattened row by row.
|
||||
- Each cell contains `1.0` if it is occupied (either by the ball or the paddle), or `0.0` if it is empty.
|
||||
- The positions of the two `1.0` values indicate where the **ball** and **paddle** currently are.
|
||||
- `legal_actions`: a list of integers representing which actions are currently allowed.
|
||||
|
||||
Respond **only** with one integer: `0`, `1`, or `2`.
|
||||
### Actions Each action is a discrete integer:
|
||||
- `0` → Move paddle **left**
|
||||
- `1` → **Stay** (no movement)
|
||||
- `2` → Move paddle **right**
|
||||
|
||||
### Output Format Respond **only with one integer** representing your chosen action: `0`, `1`, or `2`.
|
||||
|
||||
### Current Observation
|
||||
"""
|
||||
|
||||
# Start the OpenSpiel server in background
|
||||
print("⚡ Starting FastAPI server for OpenSpiel Catch Environment...")
|
||||
|
||||
def rollout_func(
|
||||
prompts: list[str], args: GRPOConfig, processing_class, client: OpenSpielEnv, gen_url: str
|
||||
) -> dict[str, list]:
|
||||
"""Generate completions via vLLM and compute environment rewards."""
|
||||
# Determine the correct path
|
||||
work_dir = str(Path.cwd().parent.absolute())
|
||||
|
||||
server_process = subprocess.Popen(
|
||||
[sys.executable, "-m", "uvicorn", "envs.openspiel_env.server.app:app", "--host", "0.0.0.0", "--port", "8001"],
|
||||
env={**os.environ, "PYTHONPATH": f"{work_dir}/src"},
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
text=True,
|
||||
cwd=work_dir,
|
||||
)
|
||||
|
||||
print("⏳ Waiting for server to start...")
|
||||
time.sleep(5)
|
||||
|
||||
# Check if server is running
|
||||
try:
|
||||
response = requests.get(f"{ENV_URL}/health", timeout=2)
|
||||
print("\n✅ OpenSpiel Catch Environment server is running!")
|
||||
except Exception as e:
|
||||
print(f"\n❌ Server failed to start: {e}")
|
||||
print("\n📋 Checking error output...")
|
||||
server_process.poll()
|
||||
if server_process.stderr:
|
||||
stderr = server_process.stderr.read()
|
||||
if stderr:
|
||||
print(stderr)
|
||||
raise
|
||||
|
||||
|
||||
# Create HTTP client for OpenSpiel Catch Environment
|
||||
client = OpenSpielEnv(base_url=f"{ENV_URL}")
|
||||
|
||||
|
||||
def rollout_func(prompts: list[str], args: GRPOConfig, processing_class) -> dict[str, list]:
|
||||
"""
|
||||
Custom rollout function that generates completions via vLLM server and computes environment rewards.
|
||||
|
||||
The catch game expects action IDs (integers). We'll parse the model's text output to extract action choices.
|
||||
|
||||
Args:
|
||||
prompts: List of prompts to generate from
|
||||
args: GRPOConfig containing all sampling parameters
|
||||
processing_class: Tokenizer/processor for decoding completions
|
||||
|
||||
Returns:
|
||||
Dict containing prompt_ids, completion_ids, logprobs, and env_reward
|
||||
"""
|
||||
# Run full episodes for each generation to get episode rewards
|
||||
env_rewards = []
|
||||
all_prompt_ids, all_completion_ids, all_logprobs = [], [], []
|
||||
all_prompt_ids = []
|
||||
all_completion_ids = []
|
||||
all_logprobs = []
|
||||
|
||||
for base_prompt in prompts:
|
||||
for _ in range(args.num_generations):
|
||||
# Run episode: Reset environment and loop until done
|
||||
env_result = client.reset()
|
||||
obs = env_result.observation
|
||||
total_reward = 0.0
|
||||
|
||||
episode_prompt_ids, episode_completion_ids, episode_logprobs = [], [], []
|
||||
episode_prompt_ids = []
|
||||
episode_completion_ids = []
|
||||
episode_logprobs = []
|
||||
|
||||
# TODO: parallelise!
|
||||
while not obs.done:
|
||||
# FIXME: handle the addition of observation to prompt more cleanly, ideally without a train_dataset
|
||||
episode_msg = {"prompt": [{"role": "user", "content": f"{base_prompt}\n\n{obs.info_state}\n"}]}
|
||||
episode_prompt = apply_chat_template(episode_msg, processing_class)
|
||||
|
||||
payload = {
|
||||
# Generate action from model
|
||||
gen_payload = {
|
||||
"prompts": [episode_prompt["prompt"]],
|
||||
"n": 1,
|
||||
"temperature": args.temperature,
|
||||
"top_p": args.top_p,
|
||||
"top_k": -1 if args.top_k is None else args.top_k,
|
||||
"min_p": 0.0 if args.min_p is None else args.min_p,
|
||||
"max_tokens": args.max_completion_length,
|
||||
"repetition_penalty": args.repetition_penalty,
|
||||
}
|
||||
response = requests.post(gen_url, json=payload)
|
||||
response.raise_for_status()
|
||||
result = response.json()
|
||||
gen_response = requests.post(GEN_URL, json=gen_payload)
|
||||
gen_response.raise_for_status()
|
||||
gen_result = gen_response.json()
|
||||
|
||||
episode_prompt_ids.extend(result["prompt_ids"][0])
|
||||
episode_completion_ids.extend(result["completion_ids"][0])
|
||||
episode_logprobs.extend(result["logprobs"][0])
|
||||
# Collect prompt_ids, completion_ids, and logprobs from this step
|
||||
episode_prompt_ids.extend(gen_result["prompt_ids"][0])
|
||||
episode_completion_ids.extend(gen_result["completion_ids"][0])
|
||||
episode_logprobs.extend(gen_result["logprobs"][0])
|
||||
|
||||
completion_text = processing_class.batch_decode(result["completion_ids"], skip_special_tokens=True)[0]
|
||||
completion_text = processing_class.batch_decode(
|
||||
gen_result["completion_ids"], skip_special_tokens=True
|
||||
)[0]
|
||||
|
||||
# Parse action from completion
|
||||
action_id = 0 # default
|
||||
numbers = re.findall(r"\b([0-2])\b", completion_text)
|
||||
action_id = int(numbers[0]) if numbers else obs.legal_actions[0]
|
||||
if numbers:
|
||||
action_id = int(numbers[0])
|
||||
elif obs.legal_actions:
|
||||
action_id = obs.legal_actions[0]
|
||||
|
||||
# Take action in environment
|
||||
env_result = client.step(OpenSpielAction(action_id=action_id, game_name="catch"))
|
||||
total_reward += env_result.reward or 0.0
|
||||
reward = env_result.reward if env_result.reward is not None else 0.0
|
||||
total_reward += reward
|
||||
obs = env_result.observation
|
||||
|
||||
# Store episode results
|
||||
env_rewards.append(total_reward)
|
||||
all_prompt_ids.append(episode_prompt_ids)
|
||||
all_completion_ids.append(episode_completion_ids)
|
||||
@ -215,60 +210,42 @@ def rollout_func(
|
||||
}
|
||||
|
||||
|
||||
dataset = Dataset.from_dict({"prompt": [BASE_PROMPT] * 1000})
|
||||
|
||||
|
||||
def reward_from_env(completions, **kwargs):
|
||||
rewards = kwargs.get("env_reward", [])
|
||||
return [float(r) for r in rewards] if rewards else [0.0] * len(completions)
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
|
||||
# Select environment mode
|
||||
if args.env_mode == "local":
|
||||
env_url = f"http://{args.env_host}:{args.env_port}"
|
||||
server_process = start_env_server(args.env_host, args.env_port)
|
||||
elif args.env_mode == "docker":
|
||||
env_url = f"http://{args.env_host}:{args.env_port}"
|
||||
server_process = None
|
||||
print(f"🌍 Using existing Docker environment at {env_url}")
|
||||
elif args.env_mode == "space":
|
||||
env_url = args.env_host
|
||||
server_process = None
|
||||
print(f"🚀 Using Hugging Face Space environment at {env_url}")
|
||||
"""Reward function that uses the environment reward from the catch game."""
|
||||
# Extract environment rewards from kwargs (propagated via extra_fields)
|
||||
env_rewards = kwargs.get("env_reward", [])
|
||||
if env_rewards:
|
||||
return [float(reward) for reward in env_rewards]
|
||||
else:
|
||||
raise ValueError(f"Unknown env mode: {args.env_mode}")
|
||||
|
||||
gen_url = args.gen_url
|
||||
client = OpenSpielEnv(base_url=env_url)
|
||||
dataset = Dataset.from_dict({"prompt": [BASE_PROMPT] * args.dataset_size})
|
||||
|
||||
training_args = GRPOConfig(
|
||||
output_dir=f"{args.model.split('/')[-1]}-GRPO-Catch",
|
||||
vllm_mode="server",
|
||||
use_vllm=True,
|
||||
logging_steps=1,
|
||||
report_to="trackio",
|
||||
num_train_epochs=1,
|
||||
max_completion_length=4,
|
||||
gradient_accumulation_steps=4,
|
||||
)
|
||||
|
||||
trainer = GRPOTrainer(
|
||||
model=args.model,
|
||||
reward_funcs=reward_from_env,
|
||||
args=training_args,
|
||||
train_dataset=dataset,
|
||||
rollout_func=lambda p, a, pc: rollout_func(p, a, pc, client, gen_url),
|
||||
callbacks=[RichProgressCallback()],
|
||||
)
|
||||
|
||||
trainer.train()
|
||||
time.sleep(5)
|
||||
|
||||
if server_process:
|
||||
print("🛑 Terminating environment server...")
|
||||
server_process.terminate()
|
||||
# Fallback if env_reward is not available
|
||||
return [0.0] * len(completions)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
training_args = GRPOConfig(
|
||||
output_dir="Qwen2.5-0.5B-GRPO-Catch",
|
||||
vllm_mode="server",
|
||||
use_vllm=True,
|
||||
logging_steps=1,
|
||||
report_to="trackio",
|
||||
num_train_epochs=1,
|
||||
max_completion_length=4,
|
||||
gradient_accumulation_steps=4,
|
||||
)
|
||||
trainer = GRPOTrainer(
|
||||
model="Qwen/Qwen2.5-0.5B-Instruct",
|
||||
reward_funcs=reward_from_env,
|
||||
args=training_args,
|
||||
train_dataset=dataset,
|
||||
rollout_func=rollout_func,
|
||||
callbacks=[RichProgressCallback()],
|
||||
)
|
||||
trainer.train()
|
||||
|
||||
# Give time for background threads to finish
|
||||
time.sleep(5)
|
||||
|
||||
print("🛑 Terminating environment server...")
|
||||
server_process.terminate()
|
||||
|
||||
@ -12,38 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Simple script to run GRPO training with OpenEnv's Echo environment and a vLLM server. The reward function encourages
|
||||
longer completions.
|
||||
|
||||
Setup:
|
||||
|
||||
```sh
|
||||
uv pip install git+https://github.com/meta-pytorch/OpenEnv.git
|
||||
```
|
||||
|
||||
Usage (2 GPUs required):
|
||||
|
||||
# Start the docker container for the Echo environment (recommended). Alternatively, you can run it locally or directly from a HF Space.
|
||||
```sh
|
||||
docker run -d -p 8001:8001 registry.hf.space/openenv-echo-env:latest
|
||||
```
|
||||
|
||||
# Spin up server
|
||||
|
||||
```sh
|
||||
CUDA_VISIBLE_DEVICES=0 trl vllm-serve --model Qwen/Qwen2.5-0.5B-Instruct --host 0.0.0.0 --port 8000
|
||||
```
|
||||
|
||||
# Run training
|
||||
|
||||
```sh
|
||||
CUDA_VISIBLE_DEVICES=1 python examples/scripts/openenv/echo.py
|
||||
```
|
||||
"""
|
||||
|
||||
# ruff: noqa: T201
|
||||
import argparse
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
@ -58,73 +27,80 @@ from envs.echo_env.models import EchoAction
|
||||
from trl import GRPOConfig, GRPOTrainer, RichProgressCallback
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(description="Run GRPO training with Echo environment and vLLM.")
|
||||
"""
|
||||
Simple script to run GRPO training with OpenEnv's Echo environment and a vLLM server. The reward function encourages
|
||||
longer completions.
|
||||
|
||||
parser.add_argument("--env-host", type=str, default="0.0.0.0", help="Host for the Echo environment.")
|
||||
parser.add_argument("--env-port", type=int, default=8001, help="Port for the Echo environment.")
|
||||
parser.add_argument(
|
||||
"--env-mode",
|
||||
choices=["local", "docker", "space"],
|
||||
default="docker",
|
||||
help="Where to run the Echo environment: 'local' to launch it, 'docker' if already running, or 'space' to use a remote Space URL.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--gen-url",
|
||||
type=str,
|
||||
default="http://0.0.0.0:8000/generate/",
|
||||
help="Base URL for the vLLM generation endpoint.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model",
|
||||
type=str,
|
||||
default="Qwen/Qwen2.5-0.5B-Instruct",
|
||||
help="Model to use for training.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dataset",
|
||||
type=str,
|
||||
default="trl-lib/ultrafeedback-prompt",
|
||||
help="Dataset to use for training.",
|
||||
)
|
||||
Setup:
|
||||
|
||||
return parser.parse_args()
|
||||
```sh
|
||||
uv pip install git+https://github.com/meta-pytorch/OpenEnv.git
|
||||
```
|
||||
|
||||
Usage (2 GPUs required):
|
||||
|
||||
# Spin up server
|
||||
|
||||
```sh
|
||||
CUDA_VISIBLE_DEVICES=0 trl vllm-serve --model Qwen/Qwen2.5-0.5B-Instruct --host 0.0.0.0 --port 8000
|
||||
```
|
||||
|
||||
# Run training
|
||||
|
||||
```sh
|
||||
CUDA_VISIBLE_DEVICES=1 python examples/scripts/openenv/echo.py
|
||||
```
|
||||
"""
|
||||
|
||||
GEN_URL = "http://0.0.0.0:8000/generate/"
|
||||
ENV_URL = "http://0.0.0.0:8001"
|
||||
|
||||
print("⚡ Starting FastAPI server for Echo Environment...")
|
||||
# Workaround if you can't run the env with Docker
|
||||
work_dir = str(Path.cwd().parent.absolute())
|
||||
server_process = subprocess.Popen(
|
||||
[sys.executable, "-m", "uvicorn", "envs.echo_env.server.app:app", "--host", "0.0.0.0", "--port", "8001"],
|
||||
env={**os.environ, "PYTHONPATH": f"{work_dir}/src"},
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
text=True,
|
||||
cwd=work_dir,
|
||||
)
|
||||
|
||||
print("⏳ Waiting for server to start...")
|
||||
time.sleep(5)
|
||||
|
||||
try:
|
||||
response = requests.get(f"{ENV_URL}/health", timeout=2)
|
||||
print("\n✅ Echo Environment server is running!")
|
||||
except Exception as e:
|
||||
print(f"\n❌ Server failed to start: {e}")
|
||||
print("\n📋 Checking error output...")
|
||||
server_process.poll()
|
||||
if server_process.stderr:
|
||||
stderr = server_process.stderr.read()
|
||||
if stderr:
|
||||
print(stderr)
|
||||
raise
|
||||
|
||||
|
||||
def start_env_server(env_host: str, env_port: int):
|
||||
"""Launch the Echo environment server locally."""
|
||||
env_url = f"http://{env_host}:{env_port}"
|
||||
print(f"⚡ Starting FastAPI server for Echo Environment on {env_url}...")
|
||||
|
||||
work_dir = str(Path.cwd().parent.absolute())
|
||||
process = subprocess.Popen(
|
||||
[sys.executable, "-m", "uvicorn", "envs.echo_env.server.app:app", "--host", env_host, "--port", str(env_port)],
|
||||
env={**os.environ, "PYTHONPATH": f"{work_dir}/src"},
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
text=True,
|
||||
cwd=work_dir,
|
||||
)
|
||||
|
||||
print("⏳ Waiting for server to start...")
|
||||
time.sleep(5)
|
||||
|
||||
try:
|
||||
requests.get(f"{env_url}/health", timeout=2)
|
||||
print("\n✅ Echo Environment server is running!")
|
||||
except Exception as e:
|
||||
print(f"\n❌ Server failed to start: {e}")
|
||||
if process.stderr:
|
||||
print(process.stderr.read())
|
||||
raise
|
||||
|
||||
return process
|
||||
# Create HTTP client for Echo Environment
|
||||
client = EchoEnv(base_url=f"{ENV_URL}")
|
||||
|
||||
|
||||
def rollout_func(
|
||||
prompts: list[str], args: GRPOConfig, processing_class, client: EchoEnv, gen_url: str
|
||||
) -> dict[str, list]:
|
||||
"""Generate completions via vLLM and compute environment rewards."""
|
||||
def rollout_func(prompts: list[str], args: GRPOConfig, processing_class) -> dict[str, list]:
|
||||
"""
|
||||
Custom rollout function that generates completions via vLLM server and computes environment rewards.
|
||||
|
||||
Args:
|
||||
prompts: List of prompts to generate from
|
||||
args: GRPOConfig containing all sampling parameters
|
||||
processing_class: Tokenizer/processor for decoding completions
|
||||
|
||||
Returns:
|
||||
Dict containing prompt_ids, completion_ids, logprobs, and env_reward
|
||||
"""
|
||||
# 1. Generate completions via vLLM inference server (running on port 8000)
|
||||
payload = {
|
||||
"prompts": prompts,
|
||||
"n": args.num_generations,
|
||||
@ -135,80 +111,64 @@ def rollout_func(
|
||||
"max_tokens": args.max_completion_length,
|
||||
"repetition_penalty": args.repetition_penalty,
|
||||
}
|
||||
response = requests.post(GEN_URL, json=payload)
|
||||
|
||||
response = requests.post(gen_url, json=payload)
|
||||
if response.status_code != 200:
|
||||
print(f"Error response: {response.text}")
|
||||
response.raise_for_status()
|
||||
|
||||
response.raise_for_status()
|
||||
result = response.json()
|
||||
|
||||
completions_text = processing_class.batch_decode(result["completion_ids"], skip_special_tokens=True)
|
||||
|
||||
# 2. Step through the environment to get rewards
|
||||
env_result = client.reset()
|
||||
env_rewards = []
|
||||
for msg in completions_text:
|
||||
env_result = client.step(EchoAction(message=msg))
|
||||
env_rewards.append(env_result.reward)
|
||||
|
||||
# 3. Add environment rewards as extra field
|
||||
result["env_reward"] = env_rewards
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def reward_from_env(completions, **kwargs):
|
||||
"""Extract environment rewards for training."""
|
||||
"""Reward function that uses the environment reward."""
|
||||
# Extract environment rewards from kwargs (propagated via extra_fields)
|
||||
env_rewards = kwargs.get("env_reward", [])
|
||||
return [float(r) for r in env_rewards] if env_rewards else [0.0] * len(completions)
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
|
||||
# Select environment mode
|
||||
if args.env_mode == "local":
|
||||
env_url = f"http://{args.env_host}:{args.env_port}"
|
||||
server_process = start_env_server(args.env_host, args.env_port)
|
||||
elif args.env_mode == "docker":
|
||||
env_url = f"http://{args.env_host}:{args.env_port}"
|
||||
server_process = None
|
||||
print(f"🌍 Using existing Echo Environment (Docker) at: {env_url}")
|
||||
elif args.env_mode == "space":
|
||||
env_url = args.env_host
|
||||
server_process = None
|
||||
print(f"🚀 Using Hugging Face Space environment at: {env_url}")
|
||||
if env_rewards:
|
||||
return [float(reward) for reward in env_rewards]
|
||||
else:
|
||||
raise ValueError(f"Unknown environment mode: {args.env_mode}")
|
||||
|
||||
gen_url = args.gen_url
|
||||
client = EchoEnv(base_url=env_url)
|
||||
dataset = load_dataset(args.dataset, split="train[:1000]")
|
||||
|
||||
training_args = GRPOConfig(
|
||||
output_dir=f"{args.model.split('/')[-1]}-GRPO-Rollout",
|
||||
vllm_mode="server",
|
||||
use_vllm=True,
|
||||
logging_steps=1,
|
||||
report_to="trackio",
|
||||
num_train_epochs=1,
|
||||
max_completion_length=2048,
|
||||
gradient_accumulation_steps=4,
|
||||
)
|
||||
|
||||
trainer = GRPOTrainer(
|
||||
model=args.model,
|
||||
reward_funcs=reward_from_env,
|
||||
args=training_args,
|
||||
train_dataset=dataset,
|
||||
rollout_func=lambda p, a, pc: rollout_func(p, a, pc, client, gen_url),
|
||||
callbacks=[RichProgressCallback()],
|
||||
)
|
||||
|
||||
trainer.train()
|
||||
time.sleep(5)
|
||||
|
||||
if server_process:
|
||||
print("🛑 Terminating Echo Environment server...")
|
||||
server_process.terminate()
|
||||
# Fallback if env_reward is not available
|
||||
return [0.0] * len(completions)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
dataset = load_dataset("trl-lib/ultrafeedback-prompt", split="train[:1000]")
|
||||
|
||||
training_args = GRPOConfig(
|
||||
output_dir="Qwen2.5-0.5B-GRPO-Rollout",
|
||||
vllm_mode="server",
|
||||
use_vllm=True,
|
||||
logging_steps=1,
|
||||
report_to="trackio",
|
||||
num_train_epochs=1,
|
||||
max_completion_length=2048,
|
||||
gradient_accumulation_steps=4,
|
||||
)
|
||||
trainer = GRPOTrainer(
|
||||
model="Qwen/Qwen2.5-0.5B-Instruct",
|
||||
reward_funcs=reward_from_env,
|
||||
args=training_args,
|
||||
train_dataset=dataset,
|
||||
rollout_func=rollout_func,
|
||||
callbacks=[RichProgressCallback()],
|
||||
)
|
||||
trainer.train()
|
||||
|
||||
# Give time for background threads to finish
|
||||
time.sleep(5)
|
||||
|
||||
print("🛑 Terminating Echo Environment server...")
|
||||
server_process.terminate()
|
||||
|
||||
@ -13,33 +13,18 @@
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Simple script to run GRPO training with OpenEnv's Wordle environment and a vLLM server.
|
||||
GRPO training for Wordle using TRL's `GRPOTrainer` and the TextArena OpenEnv environment.
|
||||
|
||||
Setup:
|
||||
Usage:
|
||||
# First, start the TextArena Wordle server (Docker or local):
|
||||
TEXTARENA_ENV_ID=Wordle-v0 TEXTARENA_NUM_PLAYERS=1 \
|
||||
python -m src.envs.textarena_env.server.app
|
||||
|
||||
```sh
|
||||
uv pip install git+https://github.com/meta-pytorch/OpenEnv.git
|
||||
```
|
||||
# Start the vLLM server with your model
|
||||
CUDA_VISIBLE_DEVICES=0 trl vllm-serve --model Qwen/Qwen2.5-0.5B-Instruct --host 0.0.0.0 --port 8000
|
||||
|
||||
Usage (2 GPUs required):
|
||||
|
||||
# Start the docker container for the Wordle environment (recommended). Alternatively, you can run it locally or directly from a HF Space.
|
||||
```sh
|
||||
docker run -d -p 8001:8001 registry.hf.space/burtenshaw-textarena:latest
|
||||
# or TEXTARENA_ENV_ID=Wordle-v0 TEXTARENA_NUM_PLAYERS=1 python -m src.envs.textarena_env.server.app
|
||||
```
|
||||
|
||||
# Spin up vLLM server
|
||||
|
||||
```sh
|
||||
CUDA_VISIBLE_DEVICES=0 trl vllm-serve --model Qwen/Qwen3-1.7B --host 0.0.0.0 --port 8000
|
||||
```
|
||||
|
||||
# Run training
|
||||
|
||||
```sh
|
||||
CUDA_VISIBLE_DEVICES=1 python examples/scripts/openenv/wordle.py
|
||||
```
|
||||
# Then run this training script:
|
||||
CUDA_VISIBLE_DEVICES=1 python examples/scripts/openenv/wordle.py
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
@ -81,12 +66,12 @@ def parse_args() -> argparse.Namespace:
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model-id",
|
||||
default="Qwen/Qwen3-1.7B",
|
||||
default="willcb/Qwen3-1.7B-Wordle",
|
||||
help="Model identifier passed to GRPOTrainer for fine-tuning.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--env-url",
|
||||
default="https://0.0.0.0:8001", # default="https://burtenshaw-textarena.hf.space"
|
||||
"--textarena-url",
|
||||
default="https://burtenshaw-textarena.hf.space",
|
||||
help="Base URL for the TextArena Wordle environment.",
|
||||
)
|
||||
parser.add_argument(
|
||||
@ -520,7 +505,7 @@ def main() -> None:
|
||||
tokenizer = AutoTokenizer.from_pretrained(cli_args.tokenizer_id)
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
env = TextArenaEnv(base_url=cli_args.env_url)
|
||||
env = TextArenaEnv(base_url=cli_args.textarena_url)
|
||||
|
||||
system_prompt = resolve_system_prompt(cli_args.system_prompt_path)
|
||||
|
||||
|
||||
@ -99,26 +99,6 @@ class TestVLLMClientServer(TrlTestCase):
|
||||
for seq in completion_ids:
|
||||
assert all(isinstance(tok, int) for tok in seq)
|
||||
|
||||
def test_chat(self):
|
||||
messages = [[{"role": "user", "content": "Hello, AI!"}], [{"role": "user", "content": "Tell me a joke"}]]
|
||||
outputs = self.client.chat(messages)
|
||||
prompt_ids = outputs["prompt_ids"]
|
||||
completion_ids = outputs["completion_ids"]
|
||||
|
||||
# Check that the outputs are lists
|
||||
assert isinstance(prompt_ids, list)
|
||||
assert isinstance(completion_ids, list)
|
||||
|
||||
# Check that the number of sequences are equal to the number of messages
|
||||
assert len(prompt_ids) == len(messages)
|
||||
assert len(completion_ids) == len(messages)
|
||||
|
||||
# Check that the sequences are lists of integers
|
||||
for seq in prompt_ids:
|
||||
assert all(isinstance(tok, int) for tok in seq)
|
||||
for seq in completion_ids:
|
||||
assert all(isinstance(tok, int) for tok in seq)
|
||||
|
||||
def test_generate_with_params(self):
|
||||
prompts = ["Hello, AI!", "Tell me a joke"]
|
||||
completion_ids = self.client.generate(prompts, n=2, repetition_penalty=0.9, temperature=0.8, max_tokens=32)[
|
||||
@ -200,26 +180,6 @@ class TestVLLMClientServerBaseURL(TrlTestCase):
|
||||
for seq in completion_ids:
|
||||
assert all(isinstance(tok, int) for tok in seq)
|
||||
|
||||
def test_chat(self):
|
||||
messages = [[{"role": "user", "content": "Hello, AI!"}], [{"role": "user", "content": "Tell me a joke"}]]
|
||||
outputs = self.client.chat(messages)
|
||||
prompt_ids = outputs["prompt_ids"]
|
||||
completion_ids = outputs["completion_ids"]
|
||||
|
||||
# Check that the outputs are lists
|
||||
assert isinstance(prompt_ids, list)
|
||||
assert isinstance(completion_ids, list)
|
||||
|
||||
# Check that the number of sequences are equal to the number of messages
|
||||
assert len(prompt_ids) == len(messages)
|
||||
assert len(completion_ids) == len(messages)
|
||||
|
||||
# Check that the sequences are lists of integers
|
||||
for seq in prompt_ids:
|
||||
assert all(isinstance(tok, int) for tok in seq)
|
||||
for seq in completion_ids:
|
||||
assert all(isinstance(tok, int) for tok in seq)
|
||||
|
||||
def test_generate_with_params(self):
|
||||
prompts = ["Hello, AI!", "Tell me a joke"]
|
||||
completion_ids = self.client.generate(prompts, n=2, repetition_penalty=0.9, temperature=0.8, max_tokens=32)[
|
||||
@ -303,26 +263,6 @@ class TestVLLMClientServerTP(TrlTestCase):
|
||||
for seq in completion_ids:
|
||||
assert all(isinstance(tok, int) for tok in seq)
|
||||
|
||||
def test_chat(self):
|
||||
messages = [[{"role": "user", "content": "Hello, AI!"}], [{"role": "user", "content": "Tell me a joke"}]]
|
||||
outputs = self.client.chat(messages)
|
||||
prompt_ids = outputs["prompt_ids"]
|
||||
completion_ids = outputs["completion_ids"]
|
||||
|
||||
# Check that the outputs are lists
|
||||
assert isinstance(prompt_ids, list)
|
||||
assert isinstance(completion_ids, list)
|
||||
|
||||
# Check that the number of sequences are equal to the number of messages
|
||||
assert len(prompt_ids) == len(messages)
|
||||
assert len(completion_ids) == len(messages)
|
||||
|
||||
# Check that the sequences are lists of integers
|
||||
for seq in prompt_ids:
|
||||
assert all(isinstance(tok, int) for tok in seq)
|
||||
for seq in completion_ids:
|
||||
assert all(isinstance(tok, int) for tok in seq)
|
||||
|
||||
def test_update_model_params(self):
|
||||
model = AutoModelForCausalLM.from_pretrained(self.model_id, device_map=torch_device)
|
||||
self.client.update_model_params(model)
|
||||
@ -386,26 +326,6 @@ class TestVLLMClientServerDP(TrlTestCase):
|
||||
for seq in completion_ids:
|
||||
assert all(isinstance(tok, int) for tok in seq)
|
||||
|
||||
def test_chat(self):
|
||||
messages = [[{"role": "user", "content": "Hello, AI!"}], [{"role": "user", "content": "Tell me a joke"}]]
|
||||
outputs = self.client.chat(messages)
|
||||
prompt_ids = outputs["prompt_ids"]
|
||||
completion_ids = outputs["completion_ids"]
|
||||
|
||||
# Check that the outputs are lists
|
||||
assert isinstance(prompt_ids, list)
|
||||
assert isinstance(completion_ids, list)
|
||||
|
||||
# Check that the number of sequences are equal to the number of messages
|
||||
assert len(prompt_ids) == len(messages)
|
||||
assert len(completion_ids) == len(messages)
|
||||
|
||||
# Check that the sequences are lists of integers
|
||||
for seq in prompt_ids:
|
||||
assert all(isinstance(tok, int) for tok in seq)
|
||||
for seq in completion_ids:
|
||||
assert all(isinstance(tok, int) for tok in seq)
|
||||
|
||||
def test_update_model_params(self):
|
||||
model = AutoModelForCausalLM.from_pretrained(self.model_id, device_map=torch_device)
|
||||
self.client.update_model_params(model)
|
||||
|
||||
@ -19,7 +19,7 @@ from typing import Any
|
||||
import torch
|
||||
from accelerate.utils import gather_object
|
||||
|
||||
from ...data_utils import apply_chat_template, is_conversational, prepare_multimodal_messages
|
||||
from ...data_utils import apply_chat_template, is_conversational
|
||||
from ...trainer.grpo_trainer import GRPOTrainer as _GRPOTrainer
|
||||
from ...trainer.utils import nanmax, nanmin, nanstd, pad
|
||||
|
||||
@ -81,17 +81,8 @@ class GFPOTrainer(_GRPOTrainer):
|
||||
if images is not None and all(img_list == [] for img_list in images):
|
||||
images = None
|
||||
|
||||
# If the prompts are conversational and the inputs contain images, we need to convert the prompts from
|
||||
# [{"role": "user", "content": "What color is the sky?"}] to
|
||||
# [{"role": "user", "content": [{"type": "image", "image": <Image>}, {"type": "text", "text": "What color is the sky?"}]}]
|
||||
if images is not None:
|
||||
prompts = [
|
||||
prepare_multimodal_messages(prompt, image_list)
|
||||
for prompt, image_list in zip(prompts, images, strict=True)
|
||||
]
|
||||
|
||||
prompt_ids_list, completion_ids_list, num_items_in_batch, sampling_per_token_logps_list, extra_fields = (
|
||||
self._generate(prompts)
|
||||
prompt_ids_list, completion_ids_list, num_items_in_batch, sampling_per_token_logps_list = self._generate(
|
||||
prompts, images
|
||||
)
|
||||
|
||||
# Convert lists of token IDs to padded tensors
|
||||
@ -290,15 +281,6 @@ class GFPOTrainer(_GRPOTrainer):
|
||||
agg_completion_lengths = self.accelerator.gather(completion_lengths)
|
||||
num_items_in_batch = agg_completion_lengths.sum()
|
||||
|
||||
if sampling_per_token_logps is not None:
|
||||
sampling_per_token_logps = sampling_per_token_logps[local_input_indices_to_keep].contiguous()
|
||||
if old_per_token_logps is not None:
|
||||
old_per_token_logps = old_per_token_logps[local_input_indices_to_keep].contiguous()
|
||||
if ref_per_token_logps is not None:
|
||||
ref_per_token_logps = ref_per_token_logps[local_input_indices_to_keep].contiguous()
|
||||
if self.use_vllm and self.vllm_importance_sampling_correction:
|
||||
importance_sampling_ratio = importance_sampling_ratio[local_input_indices_to_keep].contiguous()
|
||||
|
||||
# Calculate mean reward per function, but only for samples where the function was applied (non-NaN values)
|
||||
for i, reward_func_name in enumerate(self.reward_func_names):
|
||||
mean_rewards = torch.nanmean(rewards_per_func[:, i]).item()
|
||||
@ -327,10 +309,15 @@ class GFPOTrainer(_GRPOTrainer):
|
||||
self._logs["advantages"].extend(all_process_advantages.tolist())
|
||||
|
||||
if images is not None:
|
||||
self._logs["images"].extend(all_images)
|
||||
self._logs["images"].extend(gather_object(images))
|
||||
|
||||
if self.use_vllm and self.vllm_importance_sampling_correction:
|
||||
delta = torch.abs(old_per_token_logps - sampling_per_token_logps)
|
||||
|
||||
if self.num_remains_in_group is not None and mode == "train":
|
||||
delta = delta[local_input_indices_to_keep]
|
||||
importance_sampling_ratio = importance_sampling_ratio[local_input_indices_to_keep]
|
||||
|
||||
delta = delta[completion_mask.bool()]
|
||||
mean_delta = torch.mean(delta) if delta.numel() > 0 else torch.tensor(0.0, device=device)
|
||||
max_delta = torch.max(delta) if delta.numel() > 0 else torch.tensor(0.0, device=device)
|
||||
|
||||
@ -14,7 +14,6 @@
|
||||
|
||||
import atexit
|
||||
import base64
|
||||
import copy
|
||||
import logging
|
||||
import socket
|
||||
import time
|
||||
@ -45,13 +44,6 @@ if is_vllm_available():
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def pil_to_base64(image):
|
||||
buffer = BytesIO()
|
||||
image.save(buffer, format="PNG")
|
||||
img_bytes = buffer.getvalue()
|
||||
return base64.b64encode(img_bytes).decode("utf-8")
|
||||
|
||||
|
||||
class VLLMClient:
|
||||
"""
|
||||
A client class to interact with a vLLM server.
|
||||
@ -192,7 +184,7 @@ class VLLMClient:
|
||||
truncate_prompt_tokens: int | None = None,
|
||||
guided_decoding_regex: str | None = None,
|
||||
generation_kwargs: dict | None = None,
|
||||
) -> dict[str, list[list[int]]]:
|
||||
) -> list[list[int]]:
|
||||
"""
|
||||
Generates model completions for the provided prompts.
|
||||
|
||||
@ -237,6 +229,12 @@ class VLLMClient:
|
||||
"""
|
||||
url = f"{self.base_url}/generate/"
|
||||
|
||||
def pil_to_base64(image):
|
||||
buffer = BytesIO()
|
||||
image.save(buffer, format="PNG")
|
||||
img_bytes = buffer.getvalue()
|
||||
return base64.b64encode(img_bytes).decode("utf-8")
|
||||
|
||||
# Convert PIL images to base64 strings
|
||||
images = [pil_to_base64(img) for img in images] if images else None
|
||||
|
||||
@ -267,102 +265,6 @@ class VLLMClient:
|
||||
else:
|
||||
raise Exception(f"Request failed: {response.status_code}, {response.text}")
|
||||
|
||||
def chat(
|
||||
self,
|
||||
messages: list[list[dict]],
|
||||
n: int = 1,
|
||||
repetition_penalty: float = 1.0,
|
||||
temperature: float = 1.0,
|
||||
top_p: float = 1.0,
|
||||
top_k: int = -1,
|
||||
min_p: float = 0.0,
|
||||
max_tokens: int = 16,
|
||||
truncate_prompt_tokens: int | None = None,
|
||||
guided_decoding_regex: str | None = None,
|
||||
generation_kwargs: dict | None = None,
|
||||
chat_template_kwargs: dict | None = None,
|
||||
) -> dict[str, list[list[int]]]:
|
||||
"""
|
||||
Generates model completions for the provided chat messages.
|
||||
|
||||
Args:
|
||||
messages (`list[list[dict]]`):
|
||||
List of message lists for which the model will generate completions. Each message is a dictionary with
|
||||
keys like "role" and "content".
|
||||
n (`int`, *optional*, defaults to `1`):
|
||||
Number of completions to generate for each message list.
|
||||
repetition_penalty (`float`, *optional*, defaults to `1.0`):
|
||||
Parameter for repetition penalty. 1.0 means no penalty.
|
||||
temperature (`float`, *optional*, defaults to `1.0`):
|
||||
Temperature parameter for sampling. Higher values increase diversity.
|
||||
top_p (`float`, *optional*, defaults to `1.0`):
|
||||
Top-p sampling parameter.`1.0` means no truncation.
|
||||
top_k (`int`, *optional*, defaults to `-1`):
|
||||
Top-k sampling parameter. `-1` means no truncation.
|
||||
min_p (`float`, *optional*, defaults to `0.0`):
|
||||
Minimum probability for sampling.
|
||||
max_tokens (`int`, *optional*, defaults to `16`):
|
||||
Maximum number of tokens to generate for each message list.
|
||||
truncate_prompt_tokens (`int`, *optional*):
|
||||
If set to `-1`, will use the truncation size supported by the model. If set to an integer k, will use
|
||||
only the last k tokens from the prompt (i.e., left truncation). If set to `None`, truncation is
|
||||
disabled.
|
||||
guided_decoding_regex (`str`, *optional*):
|
||||
Regular expression to guide the decoding process.
|
||||
generation_kwargs (`dict`, *optional*):
|
||||
Additional generation parameters to pass to the vLLM `SamplingParams`. This can include parameters like
|
||||
`seed`, `frequency_penalty`, etc. If it contains keys that conflict with the other parameters, they
|
||||
will override them.
|
||||
chat_template_kwargs (`dict`, *optional*):
|
||||
Additional keyword arguments to customize the chat template used by the model.
|
||||
|
||||
Returns:
|
||||
`dict` with keys:
|
||||
- `prompt_ids` (`list[list[int]]`):
|
||||
List of lists of token IDs representing the tokenized input messages.
|
||||
- `completion_ids` (`list[list[int]]`):
|
||||
List of lists of token IDs representing the model-generated completions for each message list.
|
||||
- `logprobs` (`list[list[float]]`):
|
||||
List of lists of log probabilities for each generated token.
|
||||
"""
|
||||
url = f"{self.base_url}/chat/"
|
||||
|
||||
# Convert PIL images to base64 strings
|
||||
messages = copy.deepcopy(messages) # avoid modifying the original messages
|
||||
for message_list in messages:
|
||||
for message in message_list:
|
||||
if isinstance(message["content"], list):
|
||||
for part in message["content"]:
|
||||
if part["type"] == "image_pil":
|
||||
part["image_pil"] = pil_to_base64(part["image_pil"])
|
||||
|
||||
response = self.session.post(
|
||||
url,
|
||||
json={
|
||||
"messages": messages,
|
||||
"n": n,
|
||||
"repetition_penalty": repetition_penalty,
|
||||
"temperature": temperature,
|
||||
"top_p": top_p,
|
||||
"top_k": top_k,
|
||||
"min_p": min_p,
|
||||
"max_tokens": max_tokens,
|
||||
"truncate_prompt_tokens": truncate_prompt_tokens,
|
||||
"guided_decoding_regex": guided_decoding_regex,
|
||||
"generation_kwargs": generation_kwargs or {},
|
||||
"chat_template_kwargs": chat_template_kwargs or {},
|
||||
},
|
||||
)
|
||||
if response.status_code == 200:
|
||||
json_response = response.json()
|
||||
return {
|
||||
"prompt_ids": json_response["prompt_ids"],
|
||||
"completion_ids": json_response["completion_ids"],
|
||||
"logprobs": json_response["logprobs"],
|
||||
}
|
||||
else:
|
||||
raise Exception(f"Request failed: {response.status_code}, {response.text}")
|
||||
|
||||
def init_communicator(self, device: torch.device | str | int = 0):
|
||||
"""
|
||||
Initializes the weight update group in a distributed setup for model synchronization.
|
||||
|
||||
@ -580,7 +580,7 @@ def main(script_args: ScriptArguments):
|
||||
"max_tokens": request.max_tokens,
|
||||
"truncate_prompt_tokens": request.truncate_prompt_tokens,
|
||||
"guided_decoding": guided_decoding,
|
||||
"logprobs": 0, # enable returning log probabilities; 0 means for the sampled tokens only
|
||||
"logprobs": 0,
|
||||
}
|
||||
generation_kwargs.update(request.generation_kwargs)
|
||||
sampling_params = SamplingParams(**generation_kwargs)
|
||||
@ -615,143 +615,6 @@ def main(script_args: ScriptArguments):
|
||||
]
|
||||
return {"prompt_ids": prompt_ids, "completion_ids": completion_ids, "logprobs": logprobs}
|
||||
|
||||
class ChatRequest(BaseModel):
|
||||
messages: list[list[dict]]
|
||||
n: int = 1
|
||||
repetition_penalty: float = 1.0
|
||||
temperature: float = 1.0
|
||||
top_p: float = 1.0
|
||||
top_k: int = -1
|
||||
min_p: float = 0.0
|
||||
max_tokens: int = 16
|
||||
truncate_prompt_tokens: int | None = None
|
||||
guided_decoding_regex: str | None = None
|
||||
generation_kwargs: dict = field(default_factory=dict)
|
||||
chat_template_kwargs: dict = field(default_factory=dict)
|
||||
|
||||
class ChatResponse(BaseModel):
|
||||
prompt_ids: list[list[int]]
|
||||
completion_ids: list[list[int]]
|
||||
logprobs: list[list[float]]
|
||||
|
||||
@app.post("/chat/", response_model=ChatResponse)
|
||||
async def chat(request: ChatRequest):
|
||||
"""
|
||||
Generates completions for the provided chat messages.
|
||||
|
||||
Args:
|
||||
request (`ChatRequest`):
|
||||
- `messages` (list of `dict`): A list of messages (dicts with "role" and "content" keys) for the model
|
||||
to generate completions.
|
||||
- `n` (`int`, *optional*, defaults to `1`): Number of completions to generate for each prompt.
|
||||
- `repetition_penalty` (`float`, *optional*, defaults to `1.0`): Repetition penalty to apply during
|
||||
generation.
|
||||
- `temperature` (`float`, *optional*, defaults to `1.0`): Temperature for sampling. Higher values lead
|
||||
to more random outputs.
|
||||
- `top_p` (`float`, *optional*, defaults to `1.0`): Top-p (nucleus) sampling parameter. It controls the
|
||||
diversity of the generated text.
|
||||
- `top_k` (`int`, *optional*, defaults to `-1`): Top-k sampling parameter. If set to `-1`, it disables
|
||||
top-k sampling.
|
||||
- `min_p` (`float`, *optional*, defaults to `0.0`): Minimum probability threshold for sampling.
|
||||
- `max_tokens` (`int`, *optional*, defaults to `16`): Maximum number of tokens to generate for each
|
||||
completion.
|
||||
- `truncate_prompt_tokens` (`int`, *optional*): If set to `-1`, will use the truncation size supported
|
||||
by the model. If set to an integer k, will use only the last k tokens from the prompt (i.e., left
|
||||
truncation). If set to `None`, truncation is disabled.
|
||||
- `guided_decoding_regex` (`str`, *optional*): A regex pattern for guided decoding. If provided, the
|
||||
model will only generate tokens that match this regex pattern.
|
||||
- `generation_kwargs` (`dict`, *optional*): Additional generation parameters to pass to the vLLM
|
||||
`SamplingParams`. This can include parameters like `seed`, `frequency_penalty`, etc. If it contains
|
||||
keys that conflict with the other parameters, they will override them.
|
||||
- `chat_template_kwargs` (`dict`, *optional*): Additional keyword arguments to pass to the chat
|
||||
template.
|
||||
|
||||
Returns:
|
||||
`ChatResponse`:
|
||||
- `prompt_ids` (list of list of `int`): A list of lists of token IDs for each input prompt.
|
||||
- `completion_ids` (list of list of `int`): A list of lists of token IDs for each generated completion.
|
||||
- `logprobs` (list of list of `float`): A list of lists of log probabilities for each token in the
|
||||
generated completions.
|
||||
|
||||
Example request:
|
||||
```bash
|
||||
curl -X POST 'http://0.0.0.0:8000/chat/' \
|
||||
-H 'Content-Type: application/json' \
|
||||
-d '{"messages": [[{ "role": "user", "content": "Hello!" }]]}'
|
||||
```
|
||||
|
||||
Example response:
|
||||
```json
|
||||
{
|
||||
"prompt_ids": [[151644, 872, 198, 9707, 0, 151645, 198, 151644, 77091, 198]],
|
||||
"completion_ids":[[151667, 198, 32313, 11, 279, 1196, 1101, 1053, 330, 9707, 8958, 773, 358, 1184, 311, 5889]],
|
||||
"logprobs": [[-0.00029404606902971864, -3.576278118089249e-07, -0.09024181962013245, -6.389413465512916e-05, -0.038671817630529404, -0.00013314791431184858, -0.5868351459503174, -0.09682723134756088, -0.06609706580638885, -0.00023803261865396053, -0.02242819033563137, -0.8185162544250488, -0.04954879730939865, -0.3169460594654083, -4.887569048150908e-06, -0.006023705471307039]]
|
||||
}
|
||||
```
|
||||
"""
|
||||
# Convert PIL images to base64 strings
|
||||
for message_list in request.messages:
|
||||
for message in message_list:
|
||||
if isinstance(message["content"], list):
|
||||
for part in message["content"]:
|
||||
if part["type"] == "image_pil":
|
||||
part["image_pil"] = Image.open(BytesIO(base64.b64decode(part["image_pil"])))
# Guided decoding, if enabled
if request.guided_decoding_regex is not None:
    guided_decoding = GuidedDecodingParams(regex=request.guided_decoding_regex)
else:
    guided_decoding = None

generation_kwargs = {
    "n": request.n,
    "repetition_penalty": request.repetition_penalty,
    "temperature": request.temperature,
    "top_p": request.top_p,
    "top_k": request.top_k,
    "min_p": request.min_p,
    "max_tokens": request.max_tokens,
    "truncate_prompt_tokens": request.truncate_prompt_tokens,
    "guided_decoding": guided_decoding,
    "logprobs": 0,  # enable returning log probabilities; 0 means for the sampled tokens only
}
generation_kwargs.update(request.generation_kwargs)
sampling_params = SamplingParams(**generation_kwargs)
# Evenly distribute prompts across DP ranks
chunked_messages = chunk_list(request.messages, script_args.data_parallel_size)

# Send the messages to each worker
for connection, messages in zip(connections, chunked_messages, strict=True):
    # When the number of messages is less than data_parallel_size, some workers will receive empty messages.
    # However, vLLM requires that we always send at least one prompt. So we send a placeholder prompt to comply
    # with vLLM's requirement, and we later ignore the result.
    if not messages:
        messages = [[{"role": "user", "content": "<placeholder>"}]]
    kwargs = {
        "messages": messages,
        "sampling_params": sampling_params,
        "chat_template_kwargs": request.chat_template_kwargs,
    }
    connection.send({"type": "call", "method": "chat", "kwargs": kwargs})

# Receive results
all_outputs = [connection.recv() for connection in connections]

# Handle empty prompts (see above)
all_outputs = [output for output, prompts in zip(all_outputs, chunked_messages, strict=True) if prompts]

# Flatten and combine all results
all_outputs = list(chain.from_iterable(all_outputs))  # from list of list to single list
prompt_ids = [output.prompt_token_ids for output in all_outputs]
completion_ids = [list(output.token_ids) for outputs in all_outputs for output in outputs.outputs]
logprobs: list[list[float]] = [
    [sanitize_logprob(next(iter(logprob.values()))) for logprob in output.logprobs]
    for outputs in all_outputs
    for output in outputs.outputs
]
return {"prompt_ids": prompt_ids, "completion_ids": completion_ids, "logprobs": logprobs}
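The `chunk_list` helper used above is defined elsewhere in the script and is not shown in this diff. A minimal sketch of the assumed behaviour — splitting a list into `n` contiguous chunks whose sizes differ by at most one, with trailing empty chunks when there are fewer items than workers (the exact case the placeholder prompt guards against) — is given below; treat it as an illustration of the contract rather than the upstream definition.

```python
def chunk_list(lst: list, n: int) -> list[list]:
    # Sketch only: distribute len(lst) items over n chunks as evenly as possible.
    # When len(lst) < n, the trailing chunks come out empty.
    size, remainder = divmod(len(lst), n)
    return [lst[i * size + min(i, remainder) : (i + 1) * size + min(i + 1, remainder)] for i in range(n)]
```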
class InitCommunicatorRequest(BaseModel):
    host: str
    port: int
@@ -24,7 +24,6 @@ from pathlib import Path
from typing import Any

import datasets
import pandas as pd
import torch
import torch.utils.data
import transformers
@@ -1191,8 +1190,9 @@ class GRPOTrainer(BaseTrainer):
)
else:
if is_conversational({"prompt": ordered_set_of_prompts[0]}):
# FIXME: this endpoint doesn't exist in vllm_client
output = self.vllm_client.chat(
messages=ordered_set_of_prompts,
prompts=ordered_set_of_prompts,
**sampling_params,
chat_template_kwargs=self.chat_template_kwargs,
)
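The FIXME above notes that `vllm_client` does not yet expose a `chat` method. Purely for illustration, a hypothetical client-side wrapper around the `/chat/` endpoint documented earlier could look like the sketch below; the class name, constructor, and attribute names are assumptions, not the actual `VLLMClient` API.

```python
import requests


class VLLMChatClientSketch:
    """Illustrative only: forwards conversations to the /chat/ endpoint."""

    def __init__(self, base_url: str = "http://0.0.0.0:8000"):
        self.base_url = base_url  # assumed attribute, for illustration

    def chat(self, messages, chat_template_kwargs=None, **sampling_params):
        # Send the batch of conversations plus any sampling parameters.
        payload = {"messages": messages, **sampling_params}
        if chat_template_kwargs is not None:
            payload["chat_template_kwargs"] = chat_template_kwargs
        response = requests.post(f"{self.base_url}/chat/", json=payload)
        response.raise_for_status()
        # Mirrors ChatResponse: prompt_ids, completion_ids, logprobs.
        return response.json()
```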
@@ -1246,7 +1246,7 @@ class GRPOTrainer(BaseTrainer):
"max_tokens": self.max_completion_length,
"truncate_prompt_tokens": self.max_prompt_length,
"guided_decoding": guided_decoding,
"logprobs": 0,  # enable returning log probabilities; 0 means for the sampled tokens only
"logprobs": 0,  # only return the logprob of the generated token
}
if self.args.generation_kwargs is not None:
generation_kwargs.update(self.args.generation_kwargs)
@@ -1922,45 +1922,45 @@ class GRPOTrainer(BaseTrainer):
self.num_completions_to_print,
)

table = {
"step": [str(self.state.global_step)] * len(self._logs["prompt"]),
"prompt": self._logs["prompt"],
"completion": self._logs["completion"],
**self._logs["rewards"],
"advantage": self._logs["advantages"],
}
logging_backends = []
if self.args.report_to and "wandb" in self.args.report_to and wandb.run is not None:
logging_backends.append(wandb)
if self.args.report_to and "trackio" in self.args.report_to:
logging_backends.append(trackio)

df_base = pd.DataFrame(table)
images_raw = self._logs["images"] or []
if logging_backends:
import pandas as pd

for logging_backend in self.args.report_to:
if logging_backend == "wandb":
table = {
"step": [str(self.state.global_step)] * len(self._logs["prompt"]),
"prompt": self._logs["prompt"],
"completion": self._logs["completion"],
**self._logs["rewards"],
"advantage": self._logs["advantages"],
}

df_base = pd.DataFrame(table)
images_raw = self._logs["images"] or []

for logging_backend in logging_backends:
if images_raw:
images = []
for image_list in self._logs["images"]:
images.append([wandb.Image(image) for image in image_list])
df = pd.concat([df_base, pd.Series(images, name="image")], axis=1, copy=False)
# Convert images per backend and derive a dataframe that shares base columns
if logging_backend is wandb:
images = []
for image_list in self._logs["images"]:
images.append([wandb.Image(image) for image in image_list])
df = pd.concat([df_base, pd.Series(images, name="image")], axis=1, copy=False)
elif logging_backend is trackio:
# TODO: Implement once supported upstream https://github.com/gradio-app/trackio/issues/327
logger.info("Skipping image logging for Trackio")
df = df_base
else:
df = df_base

if self.wandb_log_unique_prompts:
df = df.drop_duplicates(subset=["prompt"])

wandb.log({"completions": wandb.Table(dataframe=df)})

if logging_backend == "trackio":
if images_raw:
# TODO: Implement once supported upstream https://github.com/gradio-app/trackio/issues/334
logger.info("Skipping image logging for Trackio")
df = df_base
# images = []
# for image_list in self._logs["images"]:
# images.append([trackio.Image(image) for image in image_list])
# df = pd.concat([df_base, pd.Series(images, name="image")], axis=1, copy=False)
else:
df = df_base

trackio.log({"completions": trackio.Table(dataframe=df)})
logging_backend.log({"completions": logging_backend.Table(dataframe=df)})

# Ensure the model card is saved along with the checkpoint
def _save_checkpoint(self, model, trial):
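To make the intent of this hunk easier to follow, here is a self-contained sketch of the backend-dispatch pattern it converges on: the active backend modules (here `wandb` and `trackio`, both of which expose `log` and `Table`) are collected into a list, and the completions table is logged through each of them with the same call shape. The function name and arguments below are illustrative, not TRL's API.

```python
import pandas as pd


def log_completions_table(table: dict, report_to: list[str], wandb=None, trackio=None) -> None:
    # Collect whichever logging backends are active for this run.
    logging_backends = []
    if wandb is not None and "wandb" in report_to and wandb.run is not None:
        logging_backends.append(wandb)
    if trackio is not None and "trackio" in report_to:
        logging_backends.append(trackio)

    df = pd.DataFrame(table)
    for backend in logging_backends:
        # Identical call shape for every backend, as in the hunk above.
        backend.log({"completions": backend.Table(dataframe=df)})
```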
@@ -24,7 +24,6 @@ from pathlib import Path
from typing import Any

import datasets
import pandas as pd
import torch
import torch.utils.data
import transformers
@@ -1011,7 +1010,7 @@ class RLOOTrainer(BaseTrainer):
with profiling_context(self, "vLLM.generate"):
if is_conversational({"prompt": ordered_set_of_prompts[0]}):
output = self.vllm_client.chat(
messages=ordered_set_of_prompts,
prompts=ordered_set_of_prompts,
**sampling_params,
chat_template_kwargs=self.chat_template_kwargs,
)
@@ -1531,45 +1530,45 @@ class RLOOTrainer(BaseTrainer):
self.num_completions_to_print,
)

table = {
"step": [str(self.state.global_step)] * len(self._logs["prompt"]),
"prompt": self._logs["prompt"],
"completion": self._logs["completion"],
**self._logs["rewards"],
"advantage": self._logs["advantages"],
}
logging_backends = []
if self.args.report_to and "wandb" in self.args.report_to and wandb.run is not None:
logging_backends.append(wandb)
if self.args.report_to and "trackio" in self.args.report_to:
logging_backends.append(trackio)

df_base = pd.DataFrame(table)
images_raw = self._logs["images"] or []
if logging_backends:
import pandas as pd

for logging_backend in self.args.report_to:
if logging_backend == "wandb":
table = {
"step": [str(self.state.global_step)] * len(self._logs["prompt"]),
"prompt": self._logs["prompt"],
"completion": self._logs["completion"],
**self._logs["rewards"],
"advantage": self._logs["advantages"],
}

df_base = pd.DataFrame(table)
images_raw = self._logs["images"] or []

for logging_backend in logging_backends:
if images_raw:
images = []
for image_list in self._logs["images"]:
images.append([wandb.Image(image) for image in image_list])
df = pd.concat([df_base, pd.Series(images, name="image")], axis=1, copy=False)
# Convert images per backend and derive a dataframe that shares base columns
if logging_backend is wandb:
images = []
for image_list in self._logs["images"]:
images.append([wandb.Image(image) for image in image_list])
df = pd.concat([df_base, pd.Series(images, name="image")], axis=1, copy=False)
elif logging_backend is trackio:
# TODO: Implement once supported upstream https://github.com/gradio-app/trackio/issues/327
logger.info("Skipping image logging for Trackio")
df = df_base
else:
df = df_base

if self.wandb_log_unique_prompts:
df = df.drop_duplicates(subset=["prompt"])

wandb.log({"completions": wandb.Table(dataframe=df)})

if logging_backend == "trackio":
if images_raw:
# TODO: Implement once supported upstream https://github.com/gradio-app/trackio/issues/334
logger.info("Skipping image logging for Trackio")
df = df_base
# images = []
# for image_list in self._logs["images"]:
# images.append([trackio.Image(image) for image in image_list])
# df = pd.concat([df_base, pd.Series(images, name="image")], axis=1, copy=False)
else:
df = df_base

trackio.log({"completions": trackio.Table(dataframe=df)})
logging_backend.log({"completions": logging_backend.Table(dataframe=df)})

# Ensure the model card is saved along with the checkpoint
def _save_checkpoint(self, model, trial):