Compare commits


132 Commits

Author SHA1 Message Date
7a2936e0a2 style 2025-10-18 00:38:17 +00:00
ba8b93831f rloo 2025-10-18 00:37:20 +00:00
c0c88071a3 fix style 2025-10-18 00:08:25 +00:00
fe11512100 dedup and some fixes 2025-10-18 00:02:48 +00:00
919ff5bced Merge branch 'main' into refactor_generate_5 2025-10-17 22:59:41 +00:00
e0eec055b4 🧺 [4/N] Refactor _generate in GRPO/RLOO: Move forward_kwargs outside generation method (#4154)
Co-authored-by: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com>
Co-authored-by: YonatanGideoni <yonatan.gideoni@gmail.com>
Co-authored-by: burtenshaw <ben.burtenshaw@gmail.com>
Co-authored-by: sergiopaniego <sergiopaniegoblanco@gmail.com>
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
2025-10-17 15:36:13 -06:00
f6e7c200c0 Merge branch 'refactor_generate_4' into refactor_generate_5 2025-10-07 12:16:00 -06:00
a0ee1e635f Merge branch 'refactor_generate_3' into refactor_generate_4 2025-10-07 12:15:32 -06:00
45290c9cfc Merge branch 'main' into refactor_generate_3 2025-10-07 12:15:11 -06:00
5e4a026160 Merge branch 'refactor_generate_4' into refactor_generate_5 2025-10-06 18:41:57 -06:00
ed54e2a1cb Merge branch 'refactor_generate_3' into refactor_generate_4 2025-10-06 18:34:31 -06:00
ee03478a14 remove test case for prompt truncation 2025-10-07 00:32:37 +00:00
e3c679c9c7 style 2025-10-06 23:59:17 +00:00
ddf3405c6c gfpo 2025-10-06 23:59:08 +00:00
2ce6c1ff41 token_type_ids and RLOO 2025-10-06 23:53:53 +00:00
34034e7f76 Merge branch 'refactor_generate_3' into refactor_generate_4 2025-10-06 17:44:45 -06:00
a84325c73b style 2025-10-06 22:35:42 +00:00
cb1d4201f7 Merge branch 'refactor_generate_4' into refactor_generate_5 2025-10-06 16:34:22 -06:00
2c012dca20 Merge branch 'refactor_generate_3' into refactor_generate_4 2025-10-06 16:25:24 -06:00
db552be924 Merge branch 'refactor_generate_2' into refactor_generate_3 2025-10-06 16:25:14 -06:00
4a274d5271 Merge branch 'main' into refactor_generate_2 2025-10-06 16:25:07 -06:00
ac2717f980 Merge branch 'refactor_generate_3' into refactor_generate_4 2025-10-06 13:21:18 -06:00
766bbcefa0 Merge branch 'refactor_generate_2' into refactor_generate_3 2025-10-06 13:19:59 -06:00
5b9a6ab7ae Merge branch 'main' into refactor_generate_2 2025-10-06 13:16:57 -06:00
df386f9667 Merge branch 'main' into refactor_generate_2 2025-10-06 10:02:54 -06:00
7f5b4995b6 Replace setup with pyproject and fix packaging unintended modules (#4194) 2025-10-06 15:56:32 +00:00
d258e36e45 Remove Optional from processing_class in PPOTrainer (#4212) 2025-10-06 15:56:32 +00:00
4fdaa4c672 Updated vLLM integration guide (#4162)
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-10-06 15:56:31 +00:00
8319ce0b75 Replace unittest with pytest (#4188) 2025-10-06 15:56:29 +00:00
6543f51a9d Hotfix: Exclude transformers 4.57.0 for Python 3.9 (#4209)
Co-authored-by: Sergio Paniego Blanco <sergiopaniegoblanco@gmail.com>
2025-10-06 15:55:39 +00:00
ae2a0e71ad Remove tokenizer creation from sft example script (#4197) 2025-10-06 15:55:39 +00:00
5d34144b6f Remove custome_container for building the docs (#4198) 2025-10-06 15:55:38 +00:00
c1e7ad2696 [DOCS/FIX] lora without regrets - fix lr (#4207) 2025-10-06 15:55:38 +00:00
21a67fc43f [DOCS] Lora without regret (#4181)
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: sergiopaniego <sergiopaniegoblanco@gmail.com>
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
2025-10-06 15:55:38 +00:00
648947911a Replace remaining trainer.tokenizer with trainer.processing_class in GRPO test (#4192) 2025-10-06 15:55:38 +00:00
f9c3c3c726 🌡️ Have vLLM return processed (temperature scaled) log probs (#4163)
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-10-06 15:55:38 +00:00
cf9d8e76c4 Hotfix wrong formatting of docstrings with blockquote tips (#4187) 2025-10-06 15:55:38 +00:00
192deb3b2b Fix CI ImportError: FlashAttention2 and decorator order for all parameterized tests (#4176) 2025-10-06 15:55:38 +00:00
e82db740f0 🔣 Fix test: replace trainer.tokenizer by trainer.processing_class (#4185) 2025-10-06 15:55:38 +00:00
d599c207cd Merge branch 'main' into refactor_generate_2 2025-10-01 08:49:04 -06:00
377b0811c9 rm test_training_vlm_and_prompt_truncation 2025-10-01 02:28:38 +00:00
c434fa23bf truncation_side=left 2025-10-01 02:28:07 +00:00
ddfd3b58c9 same for rloo 2025-10-01 02:14:43 +00:00
4dce145d40 remove vision tokens 2025-10-01 01:09:40 +00:00
5cc6af57a5 Merge branch 'refactor_generate_2' into refactor_generate_3 2025-09-30 19:00:51 -06:00
5fca5b8802 fix normal generation path 2025-10-01 00:46:15 +00:00
49577adb19 Same for RLOO 2025-10-01 00:17:37 +00:00
e164ec5aab repicate all_prompt_ids 2025-10-01 00:11:48 +00:00
e7aa945273 fix vllm client server 2025-09-30 23:10:16 +00:00
f11759e66d Merge branch 'main' into refactor_generate_2 2025-09-30 16:29:59 -06:00
a01b9caf81 Merge branch 'refactor_generate_4' into refactor_generate_5 2025-09-26 19:34:32 -06:00
b0e02795e2 Merge branch 'refactor_generate_3' into refactor_generate_4 2025-09-26 19:34:15 -06:00
3f02702600 Merge branch 'refactor_generate_2' into refactor_generate_3 2025-09-26 19:34:11 -06:00
4b9c1262a9 Merge branch 'refactor_generate' into refactor_generate_2 2025-09-26 19:34:00 -06:00
e82bfb4264 Merge branch 'main' into refactor_generate 2025-09-26 19:33:52 -06:00
effb41ba5d Merge branch 'main' into refactor_generate 2025-09-26 19:12:04 -06:00
c5064d61ea gfpo 2025-09-27 00:04:17 +00:00
7b7a11d833 test and doc 2025-09-27 00:00:52 +00:00
b8c0c9b219 Merge branch 'refactor_generate_2' into refactor_generate_3 2025-09-26 17:49:26 -06:00
c8041e1ccc Merge branch 'refactor_generate' into refactor_generate_2 2025-09-26 17:48:06 -06:00
55a2480195 rloo + doc 2025-09-26 23:46:50 +00:00
15c6620c84 refactor: update prepare_multimodal_messages to accept images directly and enhance handling of structured messages 2025-09-26 23:32:38 +00:00
48a1c30e7e don't re-prepare data 2025-09-26 22:20:23 +00:00
9925199ee9 move forward_kwargs outside of generate 2025-09-26 22:14:58 +00:00
8149d0578f rm truncation test 2025-09-26 21:27:47 +00:00
35f99fd867 requires padding 2025-09-26 21:27:33 +00:00
fc263a309a rm imports 2025-09-26 20:01:37 +00:00
d8af0039fa rm useless comment 2025-09-26 19:59:12 +00:00
0b5865e8f5 ensure proper truncation and side 2025-09-26 19:57:23 +00:00
acee7d817f rm truncate_with_protected_tokens 2025-09-26 19:45:09 +00:00
11acc758c2 rm enforce eager 2025-09-26 19:43:45 +00:00
46d8eb79cf revert 2025-09-26 19:43:17 +00:00
0e2ae34a93 rely on generator for prompt truncation 2025-09-26 19:41:24 +00:00
e770efeede Merge branch 'refactor_generate' into refactor_generate_2 2025-09-26 12:57:02 -06:00
8d34d546bb remove pad token removal 2025-09-26 18:56:45 +00:00
d79b9e1c8f get prompt ids from generation 2025-09-26 18:41:51 +00:00
b3bd0b05d4 another one 2025-09-26 18:05:49 +00:00
9da4830c53 simplify a bit + comment 2025-09-26 16:22:44 +00:00
236b78b455 better 2025-09-26 16:14:18 +00:00
8766fa5cc0 consistent naming 2025-09-26 16:12:07 +00:00
53772ef7b8 getting closer 2025-09-26 16:02:03 +00:00
27dc9585a0 fix num_input_tokens_seen 2025-09-26 03:09:42 +00:00
3d8ea27c68 wrong merge commit 2025-09-26 02:54:26 +00:00
d3f1d3c801 Merge branch 'main' into refactor_generate 2025-09-25 20:51:09 -06:00
9435a9400f refactor in grpo 2025-09-26 02:48:11 +00:00
2dc69a68e0 Merge branch 'main' into generate-method 2025-09-25 18:01:23 -06:00
1a66b431d0 revert chage data utils 2025-09-25 23:57:14 +00:00
c1ae6aa787 back to working point 2025-09-25 23:56:11 +00:00
8b3a724602 progress again again 2025-09-25 23:27:53 +00:00
0213662cd4 progress continues 2025-09-25 18:24:46 +00:00
ebe32c26d8 progress 2025-09-25 06:14:02 +00:00
b0dceb97ac restart 2025-09-25 04:03:39 +00:00
b4cadde233 Merge branch 'main' into generate-method 2025-09-24 13:57:42 -06:00
ec6ad259d2 nits style and align 2025-09-24 17:26:25 +00:00
c83e710831 same for rloo 2025-09-24 17:17:14 +00:00
cdb4c76a3f Merge branch 'main' into generate-method 2025-09-24 10:09:25 -06:00
365d5017f4 Merge branch 'main' into generate-method 2025-09-23 08:55:43 -06:00
d8665e1236 Merge branch 'main' into generate-method 2025-09-22 20:21:14 -06:00
a6a8c448a0 Merge branch 'main' into generate-method 2025-09-22 18:19:32 -06:00
c5004406ff Merge branch 'multi-image-support' into generate-method 2025-09-22 18:08:02 -06:00
9b6652eed4 rm VLM x RM warning 2025-09-23 00:05:23 +00:00
1c53094868 clarify image column desc 2025-09-22 23:57:13 +00:00
05270f820f update layers to ignore 2025-09-22 23:51:57 +00:00
485781cb3e Merge branch 'main' into multi-image-support 2025-09-22 17:47:19 -06:00
562c662c2b Merge branch 'main' into multi-image-support 2025-09-22 16:42:28 -06:00
efbb03a0d6 Merge branch 'drop-image_split_sizes' into multi-image-support 2025-09-22 16:20:42 -06:00
e17ec42797 Merge branch 'main' into drop-image_split_sizes 2025-09-22 16:17:57 -06:00
d3a769fe8f fix doc 2025-09-20 17:15:13 +00:00
b628744752 rm vllm 2025-09-20 17:15:02 +00:00
4fc2b5b71d gfpo 2025-09-20 17:13:23 +00:00
4d12aebf33 Merge branch 'multi-image-support' into generate-method 2025-09-20 10:53:36 -06:00
fc52e6832d test fixed! 2025-09-20 16:26:34 +00:00
dfc0d388ab Merge branch 'drop-image_split_sizes' into multi-image-support 2025-09-20 09:52:06 -06:00
52d8bd91b0 Merge branch 'main' into drop-image_split_sizes 2025-09-20 09:51:51 -06:00
fa738768c6 skip failing test 2025-09-20 15:21:36 +00:00
f998432622 debug 2025-09-20 05:18:40 +00:00
ae1f497959 generate method 2025-09-20 05:08:48 +00:00
fc6b11fcae update test 2025-09-20 04:22:54 +00:00
529add673c oops 2025-09-20 03:55:03 +00:00
099a39bd6a peft rloo 2025-09-20 03:04:07 +00:00
1257796ba8 rloo test 2025-09-20 03:01:47 +00:00
f4c82bfc04 fix gfpo 2025-09-20 02:55:59 +00:00
d2adc63eb6 test peft 2025-09-20 02:52:33 +00:00
088897b9cd fix 2025-09-20 02:25:10 +00:00
86cc30bf3c gfpo 2025-09-20 00:43:43 +00:00
30ad7ca371 rloo 2025-09-20 00:37:54 +00:00
dcf4b92da0 no vlm reward models 2025-09-20 00:18:18 +00:00
3ca6ad5003 log with wandb 2025-09-19 23:31:06 +00:00
229c554929 multi-image grpo 2025-09-19 22:45:57 +00:00
c8933aa856 gfpo 2025-09-19 21:10:06 +00:00
449ef07919 simpler 2025-09-19 21:05:47 +00:00
552e899015 Refactor image handling: replace image_split_sizes with image_grid_thw in GRPO and RLOO trainers; update split_pixel_values_by_grid to use image_grid_thw 2025-09-19 20:57:51 +00:00
6 changed files with 384 additions and 272 deletions

View File

@ -12,13 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import itertools
import textwrap
from time import strftime
from datasets import Dataset, DatasetDict
from parameterized import parameterized
from PIL import Image
from transformers import AutoProcessor, AutoTokenizer
from trl.data_utils import (
@ -46,30 +46,46 @@ class TestPrepareMultimodalMessages:
{"role": "user", "content": "What color is the sky?"},
{"role": "assistant", "content": "It is blue."},
]
prepare_multimodal_messages(messages, num_images=1)
image = Image.new("RGB", (32, 32), color="red")
messages = prepare_multimodal_messages(messages, images=[image])
expected = [
{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "What color is the sky?"}]},
{"role": "assistant", "content": [{"type": "text", "text": "It is blue."}]},
{
"role": "user",
"content": [{"type": "image", "image": image}, {"type": "text", "text": "What color is the sky?"}],
},
{
"role": "assistant",
"content": [{"type": "text", "text": "It is blue."}],
},
]
assert messages == expected
def test_first_user_message_gets_image(self):
"""Test that only the first user message gets an image placeholder."""
"""Test that only the first user message gets an image."""
messages = [
{"role": "user", "content": "What color is the sky?"},
{"role": "assistant", "content": "It is blue."},
{"role": "user", "content": "How about the grass?"},
]
prepare_multimodal_messages(messages, num_images=1)
image = Image.new("RGB", (32, 32), color="red")
messages = prepare_multimodal_messages(messages, images=[image])
expected = [
{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "What color is the sky?"}]},
{"role": "assistant", "content": [{"type": "text", "text": "It is blue."}]},
{"role": "user", "content": [{"type": "text", "text": "How about the grass?"}]},
{
"role": "user",
"content": [{"type": "image", "image": image}, {"type": "text", "text": "What color is the sky?"}],
},
{
"role": "assistant",
"content": [{"type": "text", "text": "It is blue."}],
},
{
"role": "user",
"content": [{"type": "text", "text": "How about the grass?"}],
},
]
assert messages == expected
@ -80,20 +96,23 @@ class TestPrepareMultimodalMessages:
{"role": "user", "content": "What color is the sky?"},
{"role": "assistant", "content": "It is blue."},
]
prepare_multimodal_messages(messages, num_images=3)
images = [Image.new("RGB", (32, 32), color=color) for color in ["red", "green", "blue"]]
messages = prepare_multimodal_messages(messages, images=images)
expected = [
{
"role": "user",
"content": [
{"type": "image"},
{"type": "image"},
{"type": "image"},
{"type": "image", "image": images[0]},
{"type": "image", "image": images[1]},
{"type": "image", "image": images[2]},
{"type": "text", "text": "What color is the sky?"},
],
},
{"role": "assistant", "content": [{"type": "text", "text": "It is blue."}]},
{
"role": "assistant",
"content": [{"type": "text", "text": "It is blue."}],
},
]
assert messages == expected
@ -105,11 +124,18 @@ class TestPrepareMultimodalMessages:
{"role": "user", "content": "What color is the sky?"},
]
prepare_multimodal_messages(messages, num_images=1)
image = Image.new("RGB", (32, 32), color="red")
messages = prepare_multimodal_messages(messages, images=[image])
expected = [
{"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant"}]},
{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "What color is the sky?"}]},
{
"role": "system",
"content": [{"type": "text", "text": "You are a helpful assistant"}],
},
{
"role": "user",
"content": [{"type": "image", "image": image}, {"type": "text", "text": "What color is the sky?"}],
},
]
assert messages == expected
@ -122,10 +148,25 @@ class TestPrepareMultimodalMessages:
{"role": "assistant", "content": [{"type": "text", "text": "It is blue."}]},
]
original = copy.deepcopy(messages)
prepare_multimodal_messages(messages, num_images=1)
image = Image.new("RGB", (32, 32), color="red")
messages = prepare_multimodal_messages(messages, images=[image])
assert messages == original
expected = [
{
"role": "system",
"content": [{"type": "text", "text": "You are a helpful assistant"}],
},
{
"role": "user",
"content": [{"type": "image", "image": image}, {"type": "text", "text": "What color is the sky?"}],
},
{
"role": "assistant",
"content": [{"type": "text", "text": "It is blue."}],
},
]
assert messages == expected
def test_mixed_prepared_and_unprepared_messages(self):
"""Test handling of mixed prepared and unprepared messages."""
@ -135,12 +176,22 @@ class TestPrepareMultimodalMessages:
{"role": "user", "content": "What about the grass?"},
]
prepare_multimodal_messages(messages, num_images=1)
image = Image.new("RGB", (32, 32), color="red")
messages = prepare_multimodal_messages(messages, images=[image])
expected = [
{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "What color is the sky?"}]},
{"role": "assistant", "content": [{"type": "text", "text": "It is blue."}]},
{"role": "user", "content": [{"type": "text", "text": "What about the grass?"}]},
{
"role": "system",
"content": [{"type": "text", "text": "You are a helpful assistant"}],
},
{
"role": "user",
"content": [{"type": "image", "image": image}, {"type": "text", "text": "What color is the sky?"}],
},
{
"role": "assistant",
"content": [{"type": "text", "text": "It is blue."}],
},
]
assert messages == expected

View File

@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
from collections import defaultdict, deque
from collections.abc import Sequence
from itertools import takewhile
@ -28,19 +29,30 @@ from transformers import PreTrainedTokenizerBase, ProcessorMixin
DatasetType = TypeVar("DatasetType", Dataset, DatasetDict)
def prepare_multimodal_messages(messages: list[dict[str, Any]], num_images: int) -> None:
def prepare_multimodal_messages(messages: list[dict[str, Any]], images: list) -> list[dict[str, Any]]:
# docstyle-ignore # because <Image> is not parsable in the code block
"""
Convert messages into a structured multimodal format if needed.
Each message's content is transformed from a raw string into a list of typed parts. The first user message is
prefixed with an image placeholder, while all other user and assistant messages are wrapped as text entries.
Convert messages into a structured multimodal format and inject the provided images into the message contents.
Args:
messages (`list[dict[str, Any]]`):
Messages with `"role"` and `"content"`. Content may be a raw string before transformation.
num_images (`int`):
Number of images to include in the first user message. This is used to determine how many image
placeholders to add.
List of messages, where each message has a `"role"` key (`"system"`, `"user"`, or `"assistant"`) and a
`"content"` key containing either a string or a list of structured blocks if already prepared.
images (`list`):
List of image objects to insert.
Returns:
`list[dict[str, Any]]`: A deep-copied list of messages where every `"content"` value is a list of structured
content blocks, and all `"image"` placeholders are populated with the corresponding image objects.
Notes:
- When the input `messages` isn't already in the structured format (i.e., all `"content"` values are strings),
the function transforms them into the structured format by wrapping text in `{"type": "text", "text": ...}`
and inserting `{"type": "image"}` placeholders for the images at the beginning of the first user message.
- When the input `messages` is already in the structured format (i.e., all `"content"` values are lists of
structured blocks), the function only fills in the actual images in the existing `{"type": "image"}`
placeholders. If the number of placeholders does not match the number of provided images, an error is raised.
Example:
```python
@ -50,24 +62,28 @@ def prepare_multimodal_messages(messages: list[dict[str, Any]], num_images: int)
{"role": "assistant", "content": "It looks like a cat."},
]
# Output (num_images=1)
# Output, one image provided
[
{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "What's in this image?"}]},
{"role": "user", "content": [{"type": "image", "image": <PIL.Image.Image>}, {"type": "text", "text": "What's in this image?"}]},
{"role": "assistant", "content": [{"type": "text", "text": "It looks like a cat."}]},
]
```
"""
image_included = False
messages = copy.deepcopy(messages) # avoid modifying the original messages
# First, convert all messages to the structured format if needed, and insert image placeholders if needed
images_included = False
for message in messages:
if message["role"] == "system":
if isinstance(message["content"], str): # if already prepared, the content will be a list
message["content"] = [{"type": "text", "text": message["content"]}]
elif message["role"] == "user":
if isinstance(message["content"], str) and not image_included:
placeholders = [{"type": "image"}] * num_images
message["content"] = [*placeholders, {"type": "text", "text": message["content"]}]
image_included = True
elif isinstance(message["content"], str) and image_included:
if isinstance(message["content"], str) and not images_included:
image_entries = [{"type": "image"}] * len(images)
message["content"] = [*image_entries, {"type": "text", "text": message["content"]}]
images_included = True
elif isinstance(message["content"], str) and images_included:
message["content"] = [{"type": "text", "text": message["content"]}]
elif message["role"] == "assistant":
if isinstance(message["content"], str):
@ -75,6 +91,55 @@ def prepare_multimodal_messages(messages: list[dict[str, Any]], num_images: int)
else:
raise ValueError(f"Invalid role in message: {message['role']}. Expected 'user', 'assistant', or 'system'.")
# Then, check that the number of image placeholders matches the number of images provided
num_placeholders = sum(sum(1 for part in message["content"] if part["type"] == "image") for message in messages)
if num_placeholders != len(images):
raise ValueError(
f"Number of images provided ({len(images)}) does not match number of image placeholders ({num_placeholders})."
)
# Then, fill in the actual images in the placeholders
img_idx = 0
for message in messages:
for part in message["content"]:
if part["type"] == "image":
part["image"] = images[img_idx]
img_idx += 1
return messages
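Below, a minimal usage sketch of the new `prepare_multimodal_messages(messages, images=...)` signature, covering both the unstructured and already-structured paths described in the docstring; the sample messages and image are illustrative only.
```python
from PIL import Image

from trl.data_utils import prepare_multimodal_messages

image = Image.new("RGB", (32, 32), color="red")

# Unstructured input: string contents are wrapped into typed blocks, image
# placeholders are inserted at the start of the first user message, then filled.
messages = [
    {"role": "user", "content": "What color is the sky?"},
    {"role": "assistant", "content": "It is blue."},
]
prepared = prepare_multimodal_messages(messages, images=[image])
# prepared[0]["content"] == [
#     {"type": "image", "image": image},
#     {"type": "text", "text": "What color is the sky?"},
# ]

# Already-structured input: existing {"type": "image"} placeholders are filled
# in place; a placeholder/image count mismatch raises a ValueError.
structured = [
    {"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "Describe this."}]},
]
prepared = prepare_multimodal_messages(structured, images=[image])
```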
def prepare_multimodal_messages_vllm(messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
# docstyle-ignore # because <Image> is not parsable in the code block
"""
Convert structured multimodal messages into a format compatible with vLLM. Replaces `"type": "image"` blocks with
`"type": "image_pil"` blocks, and `"image": Image` with `"image_pil": Image`.
Args:
messages (`list[dict[str, Any]]`):
Messages with `"role"` and `"content"`. Content is expected to be a list of structured blocks.
Returns:
`list[dict[str, Any]]`:
A deep-copied list of messages compatible with vLLM's expected input format.
Example:
```python
# Input
[{"role": "user", "content": [{"type": "image", "image": <PIL.Image.Image>}, {"type": "text", "text": "What's in this image?"}]}]
# Output
[{"role": "user", "content": [{"type": "image_pil", "image_pil": <PIL.Image.Image>}, {"type": "text", "text": "What's in this image?"}]}]
```
"""
messages = copy.deepcopy(messages) # avoid modifying the original messages
for message in messages:
for part in message["content"]:
if part["type"] == "image":
part["type"] = "image_pil" # vLLM expects 'image_pil' key for images
part["image_pil"] = part.pop("image")
return messages
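And a short sketch of how the two helpers chain before conversational prompts are handed to vLLM, mirroring the trainer changes further down; the engine call is shown commented out and is illustrative.
```python
from PIL import Image

from trl.data_utils import prepare_multimodal_messages, prepare_multimodal_messages_vllm

image = Image.new("RGB", (32, 32), color="red")
messages = [{"role": "user", "content": "What's in this image?"}]

# Fill the image into structured content blocks, then rewrite them to the
# "image_pil" form expected by vLLM's chat interface.
prepared = prepare_multimodal_messages(messages, images=[image])
vllm_messages = prepare_multimodal_messages_vllm(prepared)
# vllm_messages[0]["content"][0] == {"type": "image_pil", "image_pil": image}

# Illustrative engine call, as used by the trainers below:
# outputs = llm.chat([vllm_messages], sampling_params=sampling_params, use_tqdm=False)
```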
def is_conversational(example: dict[str, Any]) -> bool:
r"""

View File

@ -18,7 +18,7 @@ from typing import Any, Callable
import torch
from accelerate.utils import gather_object
from ...data_utils import is_conversational
from ...data_utils import apply_chat_template, is_conversational
from ...trainer.grpo_trainer import GRPOTrainer as _GRPOTrainer
from ...trainer.utils import nanmax, nanmin, nanstd, pad
@ -80,13 +80,9 @@ class GFPOTrainer(_GRPOTrainer):
if images is not None and all(img_list == [] for img_list in images):
images = None
(
prompt_ids_list,
completion_ids_list,
num_items_in_batch,
sampling_per_token_logps_list,
forward_kwargs,
) = self._generate(prompts, images)
prompt_ids_list, completion_ids_list, num_items_in_batch, sampling_per_token_logps_list = self._generate(
prompts, images
)
# Convert lists of token IDs to padded tensors
prompt_ids = [torch.tensor(ids, device=device) for ids in prompt_ids_list]
@ -112,6 +108,23 @@ class GFPOTrainer(_GRPOTrainer):
# Concatenate prompt_mask with completion_mask for logit computation
prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1) # (B, P+C)
attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) # (B, P+C)
logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens
batch_size = self.args.per_device_train_batch_size if mode == "train" else self.args.per_device_eval_batch_size
num_images = [len(img_list) for img_list in images] if images is not None else None
# Get forward_kwargs for models with multimodal inputs
if images is not None:
prompts_text = [
apply_chat_template({"prompt": prompt}, self.processing_class)["prompt"] for prompt in prompts
]
prompt_inputs = self.processing_class(images=images, text=prompts_text, padding=True, return_tensors="pt")
prompt_inputs = super()._prepare_inputs(prompt_inputs)
forward_kwargs = {k: v for k, v in prompt_inputs.items() if k not in ["input_ids", "attention_mask"]}
else:
forward_kwargs = {}
# If token_type_ids are used, extend them with zeros for the completion part
if "token_type_ids" in forward_kwargs:
token_type_ids = forward_kwargs["token_type_ids"]
@ -119,11 +132,6 @@ class GFPOTrainer(_GRPOTrainer):
[token_type_ids, token_type_ids.new_zeros(completion_ids.shape)], dim=1
)
logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens
batch_size = self.args.per_device_train_batch_size if mode == "train" else self.args.per_device_eval_batch_size
num_images = [len(img_list) for img_list in images] if images is not None else None
with torch.no_grad():
# If the generation and optimization steps are misaligned—i.e., if generation does not occur at the end of
# a full optimizer step (when gradient_accumulation_steps is not a multiple of generate_every)—then the

View File

@ -44,9 +44,14 @@ from transformers import (
is_wandb_available,
)
from transformers.trainer_utils import seed_worker
from transformers.utils import is_datasets_available, is_flash_attn_2_available, is_peft_available, is_rich_available
from transformers.utils import is_datasets_available, is_peft_available, is_rich_available
from ..data_utils import apply_chat_template, is_conversational, maybe_apply_chat_template, prepare_multimodal_messages
from ..data_utils import (
apply_chat_template,
is_conversational,
prepare_multimodal_messages,
prepare_multimodal_messages_vllm,
)
from ..extras.profiling import profiling_context, profiling_decorator
from ..extras.vllm_client import VLLMClient
from ..import_utils import is_liger_kernel_available, is_vllm_available
@ -1069,30 +1074,9 @@ class GRPOTrainer(BaseTrainer):
rewards_per_func = gather(rewards_per_func)
return rewards_per_func
def _generate_single_turn(self, prompts: list[str], images: Optional[list]):
def _generate_single_turn(self, prompts: list):
device = self.accelerator.device
# If the prompts are conversational and the inputs contain images, we need to convert the prompts from
# [{"role": "user", "content": "What color is the sky?"}] to
# [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "What color is the sky?"}]}]
kwargs = {}
if images is not None:
kwargs = {"images": images}
for prompt, image_list in zip(prompts, images):
if isinstance(prompt, list): # i.e., when using conversational data
prepare_multimodal_messages(prompt, num_images=len(image_list))
prompts_text = [
maybe_apply_chat_template({"prompt": prompt}, self.processing_class)["prompt"] for prompt in prompts
]
if images is not None:
prompt_inputs = self.processing_class(text=prompts_text, padding=True, return_tensors="pt", **kwargs)
prompt_inputs = super()._prepare_inputs(prompt_inputs)
forward_kwargs = {k: v for k, v in prompt_inputs.items() if k not in ["input_ids", "attention_mask"]}
else:
forward_kwargs = {}
# Generate completions using either vLLM or regular generation
if self.use_vllm:
if self.vllm_mode == "colocate" and self.args.vllm_enable_sleep_mode:
@ -1105,38 +1089,35 @@ class GRPOTrainer(BaseTrainer):
self._move_model_to_vllm()
self._last_loaded_step = self.state.global_step
prompts = [prepare_multimodal_messages_vllm(prompt) for prompt in prompts]
# Generate completions using vLLM: gather all prompts and use them in a single call in the main process
if self.vllm_mode == "server":
all_prompts_text = gather_object(prompts_text)
if images is not None:
all_images = gather_object(images)
all_prompts = gather_object(prompts)
if self.accelerator.is_main_process:
# Since 'prompts' contains 'num_generations' duplicates, we first take unique prompts, and generate
# num_generations outputs for each one. This is faster than generating outputs for each duplicate
# prompt individually.
ordered_set_of_prompts = all_prompts_text[:: self.num_generations]
if images is not None:
ordered_set_of_images = all_images[:: self.num_generations]
else:
ordered_set_of_images = None
ordered_set_of_prompts = all_prompts[:: self.num_generations]
sampling_params = {
"n": self.num_generations,
"repetition_penalty": self.repetition_penalty,
"temperature": self.temperature,
"top_p": self.top_p,
"top_k": -1 if self.top_k is None else self.top_k,
"min_p": 0.0 if self.min_p is None else self.min_p,
"max_tokens": self.max_completion_length,
"truncate_prompt_tokens": self.max_prompt_length,
"guided_decoding_regex": self.guided_decoding_regex,
"generation_kwargs": self.args.generation_kwargs,
}
with profiling_context(self, "vLLM.generate"):
output = self.vllm_client.generate(
prompts=ordered_set_of_prompts,
images=ordered_set_of_images,
n=self.num_generations,
repetition_penalty=self.repetition_penalty,
temperature=self.temperature,
top_p=self.top_p,
top_k=-1 if self.top_k is None else self.top_k,
min_p=0.0 if self.min_p is None else self.min_p,
max_tokens=self.max_completion_length,
truncate_prompt_tokens=self.max_prompt_length,
guided_decoding_regex=self.guided_decoding_regex,
generation_kwargs=self.args.generation_kwargs,
)
if is_conversational({"prompt": ordered_set_of_prompts[0]}):
output = self.vllm_client.chat(prompts=ordered_set_of_prompts, **sampling_params)
else:
output = self.vllm_client.generate(prompts=ordered_set_of_prompts, **sampling_params)
payload = (output["prompt_ids"], output["completion_ids"], output["logprobs"])
else:
payload = None
@ -1183,31 +1164,18 @@ class GRPOTrainer(BaseTrainer):
if self.vllm_tensor_parallel_size > 1:
# Gather prompts from all ranks in the TP group and flatten.
# Each rank starts with its own prompts; after gathering, all ranks see the full group set.
orig_size = len(prompts_text)
orig_size = len(prompts)
gathered_prompts = [None for _ in range(self.vllm_tensor_parallel_size)]
torch.distributed.all_gather_object(gathered_prompts, prompts_text, group=self.tp_group)
all_prompts_text = [p for sublist in gathered_prompts for p in sublist]
if images is not None:
gathered_images = [None for _ in range(self.vllm_tensor_parallel_size)]
torch.distributed.all_gather_object(gathered_images, images, group=self.tp_group)
all_images = [img for sublist in gathered_images for img in sublist]
else:
all_images = None
torch.distributed.all_gather_object(gathered_prompts, prompts, group=self.tp_group)
all_prompts = [p for sublist in gathered_prompts for p in sublist]
else:
all_prompts_text = prompts_text
all_images = images
if images is not None and all_images:
vllm_inputs = []
for prompt, image_list in zip(all_prompts_text, all_images):
vllm_inputs.append({"prompt": prompt, "multi_modal_data": {"image": image_list}})
else:
vllm_inputs = all_prompts_text
all_prompts = prompts
with profiling_context(self, "vLLM.generate"):
all_outputs = self.llm.generate(vllm_inputs, sampling_params=sampling_params, use_tqdm=False)
if is_conversational({"prompt": prompts[0]}):
all_outputs = self.llm.chat(all_prompts, sampling_params=sampling_params, use_tqdm=False)
else:
all_outputs = self.llm.generate(all_prompts, sampling_params=sampling_params, use_tqdm=False)
all_prompt_ids = [output.prompt_token_ids for output in all_outputs]
all_completion_ids = [output.token_ids for outputs in all_outputs for output in outputs.outputs]
@ -1234,15 +1202,20 @@ class GRPOTrainer(BaseTrainer):
self.llm.sleep(level=1)
elif self.use_transformers_paged:
# Re-process inputs for paged generation if needed
# Note: images are already validated and preprocessed above
paged_prompt_inputs = self.processing_class(text=prompts_text, **kwargs)
previous_attn = self.model_wrapped.config._attn_implementation
if is_flash_attn_2_available():
self.model_wrapped.config._attn_implementation = "paged_attention"
processor_kwargs = {
"max_length": self.max_prompt_length,
"truncation": True,
"return_dict": True,
"add_special_tokens": False,
}
if is_conversational({"prompt": prompts[0]}):
generate_inputs = self.processing_class.apply_chat_template(
conversation=prompts, **processor_kwargs, tokenize=True
)
else:
self.model_wrapped.config._attn_implementation = "sdpa_paged"
generate_inputs = self.processing_class(text=prompts, **processor_kwargs)
generate_inputs["inputs"] = generate_inputs.pop("input_ids")
with (
profiling_context(self, "transformers.generate_batch"),
unwrap_model_for_generation(
@ -1258,27 +1231,30 @@ class GRPOTrainer(BaseTrainer):
unwrapped_model.to(torch.float16)
with torch.inference_mode():
all_outputs = unwrapped_model.generate_batch(
paged_prompt_inputs.input_ids, generation_config=self.generation_config, progress_bar=False
**generate_inputs, generation_config=self.generation_config, progress_bar=False
)
unwrapped_model.train() # restore training mode, as generate_batch forces eval mode
completion_ids = [output.generated_tokens for output in all_outputs.values()]
prompt_ids = paged_prompt_inputs.input_ids
# Restore the original attention implementation, training mode
self.model_wrapped.config._attn_implementation = previous_attn
prompt_ids = generate_inputs["inputs"]
logprobs = None # not used in this case
else:
# Regular generation path
generate_inputs = self.processing_class(
text=prompts_text,
return_tensors="pt",
padding=True,
padding_side="left",
max_length=self.max_prompt_length,
truncation=True,
add_special_tokens=False,
**kwargs,
)
processor_kwargs = {
"return_tensors": "pt",
"padding": True,
"padding_side": "left",
"max_length": self.max_prompt_length,
"truncation": True,
"return_dict": True,
"add_special_tokens": False,
}
if is_conversational({"prompt": prompts[0]}):
generate_inputs = self.processing_class.apply_chat_template(
conversation=prompts, **processor_kwargs, tokenize=True
)
else:
generate_inputs = self.processing_class(text=prompts, **processor_kwargs)
generate_inputs = super()._prepare_inputs(generate_inputs)
with (
@ -1307,13 +1283,13 @@ class GRPOTrainer(BaseTrainer):
completion_ids = [c[m].tolist() for c, m in zip(completion_ids, completion_mask.bool())]
logprobs = None # not used in this case
return prompt_ids, completion_ids, logprobs, forward_kwargs
return prompt_ids, completion_ids, logprobs
def _generate(self, prompts: list[str], images: Optional[list]):
def _generate(self, prompts: list[str]):
device = self.accelerator.device
mode = "train" if self.model.training else "eval"
prompt_ids, completion_ids, logprobs, forward_kwargs = self._generate_single_turn(prompts, images)
prompt_ids, completion_ids, logprobs = self._generate_single_turn(prompts)
# Get completion length per sequence, used for logging
prompt_lengths = torch.tensor([len(ids) for ids in prompt_ids], device=device)
@ -1345,7 +1321,7 @@ class GRPOTrainer(BaseTrainer):
self._metrics[mode]["completions/min_terminated_length"].append(term_completion_lengths.float().min().item())
self._metrics[mode]["completions/max_terminated_length"].append(term_completion_lengths.float().max().item())
return prompt_ids, completion_ids, total_completion_tokens, logprobs, forward_kwargs
return prompt_ids, completion_ids, total_completion_tokens, logprobs
def _generate_and_score_completions(
self, inputs: list[dict[str, Union[torch.Tensor, Any]]]
@ -1365,13 +1341,15 @@ class GRPOTrainer(BaseTrainer):
if images is not None and all(img_list == [] for img_list in images):
images = None
(
prompt_ids_list,
completion_ids_list,
num_items_in_batch,
sampling_per_token_logps_list,
forward_kwargs,
) = self._generate(prompts, images)
# If the prompts are conversational and the inputs contain images, we need to convert the prompts from
# [{"role": "user", "content": "What color is the sky?"}] to
# [{"role": "user", "content": [{"type": "image", "image": <Image>}, {"type": "text", "text": "What color is the sky?"}]}]
if images is not None:
prompts = [prepare_multimodal_messages(prompt, image_list) for prompt, image_list in zip(prompts, images)]
prompt_ids_list, completion_ids_list, num_items_in_batch, sampling_per_token_logps_list = self._generate(
prompts
)
# Convert lists of token IDs to padded tensors
prompt_ids = [torch.tensor(ids, device=device) for ids in prompt_ids_list]
@ -1397,6 +1375,23 @@ class GRPOTrainer(BaseTrainer):
# Concatenate prompt_mask with completion_mask for logit computation
prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1) # (B, P+C)
attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) # (B, P+C)
logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens
batch_size = self.args.per_device_train_batch_size if mode == "train" else self.args.per_device_eval_batch_size
num_images = [len(img_list) for img_list in images] if images is not None else None
# Get forward_kwargs for models with multimodal inputs
if images is not None:
prompts_text = [
apply_chat_template({"prompt": prompt}, self.processing_class)["prompt"] for prompt in prompts
]
prompt_inputs = self.processing_class(images=images, text=prompts_text, padding=True, return_tensors="pt")
prompt_inputs = super()._prepare_inputs(prompt_inputs)
forward_kwargs = {k: v for k, v in prompt_inputs.items() if k not in ["input_ids", "attention_mask"]}
else:
forward_kwargs = {}
# If token_type_ids are used, extend them with zeros for the completion part
if "token_type_ids" in forward_kwargs:
token_type_ids = forward_kwargs["token_type_ids"]
@ -1404,11 +1399,6 @@ class GRPOTrainer(BaseTrainer):
[token_type_ids, token_type_ids.new_zeros(completion_ids.shape)], dim=1
)
logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens
batch_size = self.args.per_device_train_batch_size if mode == "train" else self.args.per_device_eval_batch_size
num_images = [len(img_list) for img_list in images] if images is not None else None
with torch.no_grad():
# If the generation and optimization steps are misaligned—i.e., if generation does not occur at the end of
# a full optimizer step (when gradient_accumulation_steps is not a multiple of generate_every)—then the

View File

@ -45,9 +45,14 @@ from transformers import (
is_wandb_available,
)
from transformers.trainer_utils import seed_worker
from transformers.utils import is_datasets_available, is_flash_attn_2_available, is_peft_available, is_rich_available
from transformers.utils import is_datasets_available, is_peft_available, is_rich_available
from ..data_utils import apply_chat_template, is_conversational, maybe_apply_chat_template, prepare_multimodal_messages
from ..data_utils import (
apply_chat_template,
is_conversational,
prepare_multimodal_messages,
prepare_multimodal_messages_vllm,
)
from ..extras.profiling import profiling_context, profiling_decorator
from ..extras.vllm_client import VLLMClient
from ..import_utils import is_vllm_available
@ -1065,30 +1070,9 @@ class RLOOTrainer(BaseTrainer):
rewards_per_func = gather(rewards_per_func)
return rewards_per_func
def _generate_single_turn(self, prompts: list[str], images: Optional[list]):
def _generate_single_turn(self, prompts: list):
device = self.accelerator.device
# If the prompts are conversational and the inputs contain images, we need to convert the prompts from
# [{"role": "user", "content": "What color is the sky?"}] to
# [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "What color is the sky?"}]}]
kwargs = {}
if images is not None:
kwargs = {"images": images}
for prompt, image_list in zip(prompts, images):
if isinstance(prompt, list): # i.e., when using conversational data
prepare_multimodal_messages(prompt, num_images=len(image_list))
prompts_text = [
maybe_apply_chat_template({"prompt": prompt}, self.processing_class)["prompt"] for prompt in prompts
]
if images is not None:
prompt_inputs = self.processing_class(text=prompts_text, padding=True, return_tensors="pt", **kwargs)
prompt_inputs = super()._prepare_inputs(prompt_inputs)
forward_kwargs = {k: v for k, v in prompt_inputs.items() if k not in ["input_ids", "attention_mask"]}
else:
forward_kwargs = {}
# Generate completions using either vLLM or regular generation
if self.use_vllm:
if self.vllm_mode == "colocate" and self.args.vllm_enable_sleep_mode:
@ -1101,38 +1085,35 @@ class RLOOTrainer(BaseTrainer):
self._move_model_to_vllm()
self._last_loaded_step = self.state.global_step
prompts = [prepare_multimodal_messages_vllm(prompt) for prompt in prompts]
# Generate completions using vLLM: gather all prompts and use them in a single call in the main process
if self.vllm_mode == "server":
all_prompts_text = gather_object(prompts_text)
if images is not None:
all_images = gather_object(images)
all_prompts = gather_object(prompts)
if self.accelerator.is_main_process:
# Since 'prompts' contains 'num_generations' duplicates, we first take unique prompts, and generate
# num_generations outputs for each one. This is faster than generating outputs for each duplicate
# prompt individually.
ordered_set_of_prompts = all_prompts_text[:: self.num_generations]
if images is not None:
ordered_set_of_images = all_images[:: self.num_generations]
else:
ordered_set_of_images = None
ordered_set_of_prompts = all_prompts[:: self.num_generations]
sampling_params = {
"n": self.num_generations,
"repetition_penalty": self.repetition_penalty,
"temperature": self.temperature,
"top_p": self.top_p,
"top_k": -1 if self.top_k is None else self.top_k,
"min_p": 0.0 if self.min_p is None else self.min_p,
"max_tokens": self.max_completion_length,
"truncate_prompt_tokens": self.max_prompt_length,
"guided_decoding_regex": self.guided_decoding_regex,
"generation_kwargs": self.args.generation_kwargs,
}
with profiling_context(self, "vLLM.generate"):
output = self.vllm_client.generate(
prompts=ordered_set_of_prompts,
images=ordered_set_of_images,
n=self.num_generations,
repetition_penalty=self.repetition_penalty,
temperature=self.temperature,
top_p=self.top_p,
top_k=-1 if self.top_k is None else self.top_k,
min_p=0.0 if self.min_p is None else self.min_p,
max_tokens=self.max_completion_length,
truncate_prompt_tokens=self.max_prompt_length,
guided_decoding_regex=self.guided_decoding_regex,
generation_kwargs=self.args.generation_kwargs,
)
if is_conversational({"prompt": ordered_set_of_prompts[0]}):
output = self.vllm_client.chat(prompts=ordered_set_of_prompts, **sampling_params)
else:
output = self.vllm_client.generate(prompts=ordered_set_of_prompts, **sampling_params)
payload = (output["prompt_ids"], output["completion_ids"], output["logprobs"])
else:
payload = None
@ -1177,31 +1158,18 @@ class RLOOTrainer(BaseTrainer):
if self.vllm_tensor_parallel_size > 1:
# Gather prompts from all ranks in the TP group and flatten.
# Each rank starts with its own prompts; after gathering, all ranks see the full group set.
orig_size = len(prompts_text)
orig_size = len(prompts)
gathered_prompts = [None for _ in range(self.vllm_tensor_parallel_size)]
torch.distributed.all_gather_object(gathered_prompts, prompts_text, group=self.tp_group)
all_prompts_text = [p for sublist in gathered_prompts for p in sublist]
if images is not None:
gathered_images = [None for _ in range(self.vllm_tensor_parallel_size)]
torch.distributed.all_gather_object(gathered_images, images, group=self.tp_group)
all_images = [img for sublist in gathered_images for img in sublist]
else:
all_images = None
torch.distributed.all_gather_object(gathered_prompts, prompts, group=self.tp_group)
all_prompts = [p for sublist in gathered_prompts for p in sublist]
else:
all_prompts_text = prompts_text
all_images = images
if images is not None and all_images:
vllm_inputs = []
for prompt, image_list in zip(all_prompts_text, all_images):
vllm_inputs.append({"prompt": prompt, "multi_modal_data": {"image": image_list}})
else:
vllm_inputs = all_prompts_text
all_prompts = prompts
with profiling_context(self, "vLLM.generate"):
all_outputs = self.llm.generate(vllm_inputs, sampling_params=sampling_params, use_tqdm=False)
if is_conversational({"prompt": prompts[0]}):
all_outputs = self.llm.chat(all_prompts, sampling_params=sampling_params, use_tqdm=False)
else:
all_outputs = self.llm.generate(all_prompts, sampling_params=sampling_params, use_tqdm=False)
all_prompt_ids = [output.prompt_token_ids for output in all_outputs]
all_completion_ids = [output.token_ids for outputs in all_outputs for output in outputs.outputs]
@ -1221,15 +1189,20 @@ class RLOOTrainer(BaseTrainer):
self.llm.sleep(level=1)
elif self.use_transformers_paged:
# Re-process inputs for paged generation if needed
# Note: images are already validated and preprocessed above
paged_prompt_inputs = self.processing_class(text=prompts_text, **kwargs)
previous_attn = self.model_wrapped.config._attn_implementation
if is_flash_attn_2_available():
self.model_wrapped.config._attn_implementation = "paged_attention"
processor_kwargs = {
"max_length": self.max_prompt_length,
"truncation": True,
"return_dict": True,
"add_special_tokens": False,
}
if is_conversational({"prompt": prompts[0]}):
generate_inputs = self.processing_class.apply_chat_template(
conversation=prompts, **processor_kwargs, tokenize=True
)
else:
self.model_wrapped.config._attn_implementation = "sdpa_paged"
generate_inputs = self.processing_class(text=prompts, **processor_kwargs)
generate_inputs["inputs"] = generate_inputs.pop("input_ids")
with (
profiling_context(self, "transformers.generate_batch"),
unwrap_model_for_generation(
@ -1245,26 +1218,29 @@ class RLOOTrainer(BaseTrainer):
unwrapped_model.to(torch.float16)
with torch.inference_mode():
all_outputs = unwrapped_model.generate_batch(
paged_prompt_inputs.input_ids, generation_config=self.generation_config, progress_bar=False
**generate_inputs, generation_config=self.generation_config, progress_bar=False
)
unwrapped_model.train() # restore training mode, as generate_batch forces eval mode
completion_ids = [output.generated_tokens for output in all_outputs.values()]
prompt_ids = paged_prompt_inputs.input_ids
# Restore the original attention implementation, training mode
self.model_wrapped.config._attn_implementation = previous_attn
prompt_ids = generate_inputs["inputs"]
else:
# Regular generation path
generate_inputs = self.processing_class(
text=prompts_text,
return_tensors="pt",
padding=True,
padding_side="left",
max_length=self.max_prompt_length,
truncation=True,
add_special_tokens=False,
**kwargs,
)
processor_kwargs = {
"return_tensors": "pt",
"padding": True,
"padding_side": "left",
"max_length": self.max_prompt_length,
"truncation": True,
"return_dict": True,
"add_special_tokens": False,
}
if is_conversational({"prompt": prompts[0]}):
generate_inputs = self.processing_class.apply_chat_template(
conversation=prompts, **processor_kwargs, tokenize=True
)
else:
generate_inputs = self.processing_class(text=prompts, **processor_kwargs)
generate_inputs = super()._prepare_inputs(generate_inputs)
with (
@ -1292,13 +1268,13 @@ class RLOOTrainer(BaseTrainer):
prompt_ids = [p[m].tolist() for p, m in zip(prompt_ids, prompt_mask.bool())]
completion_ids = [c[m].tolist() for c, m in zip(completion_ids, completion_mask.bool())]
return prompt_ids, completion_ids, forward_kwargs
return prompt_ids, completion_ids
def _generate(self, prompts: list[str], images: Optional[list]):
def _generate(self, prompts: list[str]):
device = self.accelerator.device
mode = "train" if self.model.training else "eval"
prompt_ids, completion_ids, forward_kwargs = self._generate_single_turn(prompts, images)
prompt_ids, completion_ids = self._generate_single_turn(prompts)
# Get completion length per sequence, used for logging
prompt_lengths = torch.tensor([len(ids) for ids in prompt_ids], device=device)
@ -1331,7 +1307,7 @@ class RLOOTrainer(BaseTrainer):
self._metrics[mode]["completions/min_terminated_length"].append(term_completion_lengths.float().min().item())
self._metrics[mode]["completions/max_terminated_length"].append(term_completion_lengths.float().max().item())
return prompt_ids, completion_ids, forward_kwargs
return prompt_ids, completion_ids
def _generate_and_score_completions(
self, inputs: list[dict[str, Union[torch.Tensor, Any]]]
@ -1351,7 +1327,13 @@ class RLOOTrainer(BaseTrainer):
if images is not None and all(img_list == [] for img_list in images):
images = None
prompt_ids_list, completion_ids_list, forward_kwargs = self._generate(prompts, images)
# If the prompts are conversational and the inputs contain images, we need to convert the prompts from
# [{"role": "user", "content": "What color is the sky?"}] to
# [{"role": "user", "content": [{"type": "image", "image": <Image>}, {"type": "text", "text": "What color is the sky?"}]}]
if images is not None:
prompts = [prepare_multimodal_messages(prompt, image_list) for prompt, image_list in zip(prompts, images)]
prompt_ids_list, completion_ids_list = self._generate(prompts)
# Convert lists of token IDs to padded tensors
prompt_ids = [torch.tensor(ids, device=device) for ids in prompt_ids_list]
@ -1372,6 +1354,23 @@ class RLOOTrainer(BaseTrainer):
# Concatenate prompt_mask with completion_mask for logit computation
prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1) # (B, P+C)
attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) # (B, P+C)
logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens
batch_size = self.args.per_device_train_batch_size if mode == "train" else self.args.per_device_eval_batch_size
num_images = [len(img_list) for img_list in images] if images is not None else None
# Get forward_kwargs for models with multimodal inputs
if images is not None:
prompts_text = [
apply_chat_template({"prompt": prompt}, self.processing_class)["prompt"] for prompt in prompts
]
prompt_inputs = self.processing_class(images=images, text=prompts_text, padding=True, return_tensors="pt")
prompt_inputs = super()._prepare_inputs(prompt_inputs)
forward_kwargs = {k: v for k, v in prompt_inputs.items() if k not in ["input_ids", "attention_mask"]}
else:
forward_kwargs = {}
# If token_type_ids are used, extend them with zeros for the completion part
if "token_type_ids" in forward_kwargs:
token_type_ids = forward_kwargs["token_type_ids"]
@ -1379,11 +1378,6 @@ class RLOOTrainer(BaseTrainer):
[token_type_ids, token_type_ids.new_zeros(completion_ids.shape)], dim=1
)
logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens
batch_size = self.args.per_device_train_batch_size if mode == "train" else self.args.per_device_eval_batch_size
num_images = [len(img_list) for img_list in images] if images is not None else None
with torch.no_grad():
# Compute the per-token log probabilities for the current model
old_per_token_logps, _ = self._get_per_token_logps_and_entropies(

View File

@ -353,9 +353,7 @@ class DataCollatorForVisionLanguageModeling(DataCollatorMixin):
images = None
if "messages" in examples[0]: # conversational case
for example in examples:
prepare_multimodal_messages(example["messages"], len(example["images"]))
messages = [example["messages"] for example in examples]
messages = [prepare_multimodal_messages(example["messages"], example["images"]) for example in examples]
texts = self.processor.apply_chat_template(messages)
elif self.dataset_text_field in examples[0]: # standard case
texts = [example[self.dataset_text_field] for example in examples]
@ -396,7 +394,8 @@ class DataCollatorForVisionLanguageModeling(DataCollatorMixin):
images = None
if is_conversational(examples[0]): # conversational case
for example in examples:
prepare_multimodal_messages(example["prompt"] + example["completion"], len(example["images"]))
example["prompt"] = prepare_multimodal_messages(example["prompt"], images=example["images"])
example["completion"] = prepare_multimodal_messages(example["completion"], images=[])
examples = [apply_chat_template(example, self.processor) for example in examples]
prompts = [example["prompt"] for example in examples]
@ -951,10 +950,13 @@ class SFTTrainer(BaseTrainer):
output = {}
if is_conversational(example):
if self._is_vlm:
prepare_multimodal_messages(example["prompt"], num_images=0)
prepare_multimodal_messages(example["completion"], num_images=0)
prompt = prepare_multimodal_messages(example["prompt"], images=[])
completion = prepare_multimodal_messages(example["completion"], images=[])
else:
prompt = example["prompt"]
completion = example["completion"]
prompt_ids = processing_class.apply_chat_template(
example["prompt"],
prompt,
tokenize=True,
add_generation_prompt=True,
tools=example.get("tools"),
@ -964,7 +966,7 @@ class SFTTrainer(BaseTrainer):
# even for single examples, while for LLMs it returns lists of ints.
prompt_ids = prompt_ids[0] if isinstance(prompt_ids[0], list) else prompt_ids
prompt_completion_processed = processing_class.apply_chat_template(
example["prompt"] + example["completion"],
prompt + completion,
return_dict=True,
tokenize=True,
return_assistant_tokens_mask=assistant_only_loss,
@ -1002,9 +1004,11 @@ class SFTTrainer(BaseTrainer):
else: # language modeling case
if is_conversational(example):
if self._is_vlm:
prepare_multimodal_messages(example["messages"], num_images=0)
messages = prepare_multimodal_messages(example["messages"], images=[])
else:
messages = example["messages"]
processed = processing_class.apply_chat_template(
example["messages"],
messages,
return_dict=True,
tokenize=True,
return_assistant_tokens_mask=assistant_only_loss,