Mirror of https://github.com/huggingface/trl.git, synced 2025-10-20 18:43:52 +08:00
Compare commits
132 Commits
activation
...
7a2936e0a2
Commits (SHA1):
7a2936e0a2
ba8b93831f
c0c88071a3
fe11512100
919ff5bced
e0eec055b4
f6e7c200c0
a0ee1e635f
45290c9cfc
5e4a026160
ed54e2a1cb
ee03478a14
e3c679c9c7
ddf3405c6c
2ce6c1ff41
34034e7f76
a84325c73b
cb1d4201f7
2c012dca20
db552be924
4a274d5271
ac2717f980
766bbcefa0
5b9a6ab7ae
df386f9667
7f5b4995b6
d258e36e45
4fdaa4c672
8319ce0b75
6543f51a9d
ae2a0e71ad
5d34144b6f
c1e7ad2696
21a67fc43f
648947911a
f9c3c3c726
cf9d8e76c4
192deb3b2b
e82db740f0
d599c207cd
377b0811c9
c434fa23bf
ddfd3b58c9
4dce145d40
5cc6af57a5
5fca5b8802
49577adb19
e164ec5aab
e7aa945273
f11759e66d
a01b9caf81
b0e02795e2
3f02702600
4b9c1262a9
e82bfb4264
effb41ba5d
c5064d61ea
7b7a11d833
b8c0c9b219
c8041e1ccc
55a2480195
15c6620c84
48a1c30e7e
9925199ee9
8149d0578f
35f99fd867
fc263a309a
d8af0039fa
0b5865e8f5
acee7d817f
11acc758c2
46d8eb79cf
0e2ae34a93
e770efeede
8d34d546bb
d79b9e1c8f
b3bd0b05d4
9da4830c53
236b78b455
8766fa5cc0
53772ef7b8
27dc9585a0
3d8ea27c68
d3f1d3c801
9435a9400f
2dc69a68e0
1a66b431d0
c1ae6aa787
8b3a724602
0213662cd4
ebe32c26d8
b0dceb97ac
b4cadde233
ec6ad259d2
c83e710831
cdb4c76a3f
365d5017f4
d8665e1236
a6a8c448a0
c5004406ff
9b6652eed4
1c53094868
05270f820f
485781cb3e
562c662c2b
efbb03a0d6
e17ec42797
d3a769fe8f
b628744752
4fc2b5b71d
4d12aebf33
fc52e6832d
dfc0d388ab
52d8bd91b0
fa738768c6
f998432622
ae1f497959
fc6b11fcae
529add673c
099a39bd6a
1257796ba8
f4c82bfc04
d2adc63eb6
088897b9cd
86cc30bf3c
30ad7ca371
dcf4b92da0
3ca6ad5003
229c554929
c8933aa856
449ef07919
552e899015
@@ -12,13 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import itertools
import textwrap
from time import strftime

from datasets import Dataset, DatasetDict
from parameterized import parameterized
from PIL import Image
from transformers import AutoProcessor, AutoTokenizer

from trl.data_utils import (
@@ -46,30 +46,46 @@ class TestPrepareMultimodalMessages:
{"role": "user", "content": "What color is the sky?"},
{"role": "assistant", "content": "It is blue."},
]

prepare_multimodal_messages(messages, num_images=1)
image = Image.new("RGB", (32, 32), color="red")
messages = prepare_multimodal_messages(messages, images=[image])

expected = [
{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "What color is the sky?"}]},
{"role": "assistant", "content": [{"type": "text", "text": "It is blue."}]},
{
"role": "user",
"content": [{"type": "image", "image": image}, {"type": "text", "text": "What color is the sky?"}],
},
{
"role": "assistant",
"content": [{"type": "text", "text": "It is blue."}],
},
]

assert messages == expected

def test_first_user_message_gets_image(self):
"""Test that only the first user message gets an image placeholder."""
"""Test that only the first user message gets an image."""
messages = [
{"role": "user", "content": "What color is the sky?"},
{"role": "assistant", "content": "It is blue."},
{"role": "user", "content": "How about the grass?"},
]

prepare_multimodal_messages(messages, num_images=1)
image = Image.new("RGB", (32, 32), color="red")
messages = prepare_multimodal_messages(messages, images=[image])

expected = [
{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "What color is the sky?"}]},
{"role": "assistant", "content": [{"type": "text", "text": "It is blue."}]},
{"role": "user", "content": [{"type": "text", "text": "How about the grass?"}]},
{
"role": "user",
"content": [{"type": "image", "image": image}, {"type": "text", "text": "What color is the sky?"}],
},
{
"role": "assistant",
"content": [{"type": "text", "text": "It is blue."}],
},
{
"role": "user",
"content": [{"type": "text", "text": "How about the grass?"}],
},
]

assert messages == expected
@@ -80,20 +96,23 @@ class TestPrepareMultimodalMessages:
{"role": "user", "content": "What color is the sky?"},
{"role": "assistant", "content": "It is blue."},
]

prepare_multimodal_messages(messages, num_images=3)
images = [Image.new("RGB", (32, 32), color=color) for color in ["red", "green", "blue"]]
messages = prepare_multimodal_messages(messages, images=images)

expected = [
{
"role": "user",
"content": [
{"type": "image"},
{"type": "image"},
{"type": "image"},
{"type": "image", "image": images[0]},
{"type": "image", "image": images[1]},
{"type": "image", "image": images[2]},
{"type": "text", "text": "What color is the sky?"},
],
},
{"role": "assistant", "content": [{"type": "text", "text": "It is blue."}]},
{
"role": "assistant",
"content": [{"type": "text", "text": "It is blue."}],
},
]

assert messages == expected
@@ -105,11 +124,18 @@ class TestPrepareMultimodalMessages:
{"role": "user", "content": "What color is the sky?"},
]

prepare_multimodal_messages(messages, num_images=1)
image = Image.new("RGB", (32, 32), color="red")
messages = prepare_multimodal_messages(messages, images=[image])

expected = [
{"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant"}]},
{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "What color is the sky?"}]},
{
"role": "system",
"content": [{"type": "text", "text": "You are a helpful assistant"}],
},
{
"role": "user",
"content": [{"type": "image", "image": image}, {"type": "text", "text": "What color is the sky?"}],
},
]

assert messages == expected
@@ -122,10 +148,25 @@ class TestPrepareMultimodalMessages:
{"role": "assistant", "content": [{"type": "text", "text": "It is blue."}]},
]

original = copy.deepcopy(messages)
prepare_multimodal_messages(messages, num_images=1)
image = Image.new("RGB", (32, 32), color="red")
messages = prepare_multimodal_messages(messages, images=[image])

assert messages == original
expected = [
{
"role": "system",
"content": [{"type": "text", "text": "You are a helpful assistant"}],
},
{
"role": "user",
"content": [{"type": "image", "image": image}, {"type": "text", "text": "What color is the sky?"}],
},
{
"role": "assistant",
"content": [{"type": "text", "text": "It is blue."}],
},
]

assert messages == expected

def test_mixed_prepared_and_unprepared_messages(self):
"""Test handling of mixed prepared and unprepared messages."""
@@ -135,12 +176,22 @@ class TestPrepareMultimodalMessages:
{"role": "user", "content": "What about the grass?"},
]

prepare_multimodal_messages(messages, num_images=1)
image = Image.new("RGB", (32, 32), color="red")
messages = prepare_multimodal_messages(messages, images=[image])

expected = [
{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "What color is the sky?"}]},
{"role": "assistant", "content": [{"type": "text", "text": "It is blue."}]},
{"role": "user", "content": [{"type": "text", "text": "What about the grass?"}]},
{
"role": "system",
"content": [{"type": "text", "text": "You are a helpful assistant"}],
},
{
"role": "user",
"content": [{"type": "image", "image": image}, {"type": "text", "text": "What color is the sky?"}],
},
{
"role": "assistant",
"content": [{"type": "text", "text": "It is blue."}],
},
]

assert messages == expected
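For orientation, here is a minimal standalone sketch of the call pattern exercised by the updated tests above: with this change, `prepare_multimodal_messages` takes the actual `images` rather than a `num_images` count, and returns a new structured list instead of mutating its input (the image and colors here are arbitrary placeholders):

```python
from PIL import Image

from trl.data_utils import prepare_multimodal_messages

# A plain-text conversation, as stored in a typical prompt dataset
messages = [
    {"role": "user", "content": "What color is the sky?"},
    {"role": "assistant", "content": "It is blue."},
]

image = Image.new("RGB", (32, 32), color="red")  # stand-in for a real dataset image

# New API: pass the images themselves; the original `messages` list is left untouched
prepared = prepare_multimodal_messages(messages, images=[image])

assert prepared[0]["content"][0] == {"type": "image", "image": image}
assert messages[0]["content"] == "What color is the sky?"  # input is not mutated
```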
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
from collections import defaultdict, deque
from collections.abc import Sequence
from itertools import takewhile
@@ -28,19 +29,30 @@ from transformers import PreTrainedTokenizerBase, ProcessorMixin
DatasetType = TypeVar("DatasetType", Dataset, DatasetDict)


def prepare_multimodal_messages(messages: list[dict[str, Any]], num_images: int) -> None:
def prepare_multimodal_messages(messages: list[dict[str, Any]], images: list) -> list[dict[str, Any]]:
# docstyle-ignore # because <Image> is not parsable in the code block
"""
Convert messages into a structured multimodal format if needed.

Each message's content is transformed from a raw string into a list of typed parts. The first user message is
prefixed with an image placeholder, while all other user and assistant messages are wrapped as text entries.
Convert messages into a structured multimodal format and inject the provided images into the message contents.

Args:
messages (`list[dict[str, Any]]`):
Messages with `"role"` and `"content"`. Content may be a raw string before transformation.
num_images (`int`):
Number of images to include in the first user message. This is used to determine how many image
placeholders to add.
List of messages, where each message is a dictionary with
a `"role"` key (`"system"`, `"user"`, or `"assistant"`) and a `"content"` key containing either a string or
a list of structured blocks if already prepared.
images (`list`):
List of image objects to insert.

Returns:
`list[dict[str, Any]]`: A deep-copied list of messages where every `"content"` value is a list of structured
content blocks, and all `"image"` placeholders are populated with the corresponding image objects.

Notes:
- When the input `messages` isn't already in the structured format, (i.e., all `"content"` values are strings),
the function transforms them into the structured format by wrapping text in `{"type": "text", "text": ...}`
and inserting `{"type": "image"}` placeholders for the images *before* the first user message.
- When the input `messages` is already in the structured format (i.e., all `"content"` values are lists of
structured blocks), the function only fills in the actual images in the existing `{"type": "image"}`
placeholders. If the number of placeholders does not match the number of provided images, an error is raised.

Example:
```python
@@ -50,24 +62,28 @@ def prepare_multimodal_messages(messages: list[dict[str, Any]], num_images: int)
{"role": "assistant", "content": "It looks like a cat."},
]

# Output (num_images=1)
# Output, one image provided
[
{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "What's in this image?"}]},
{"role": "user", "content": [{"type": "image", "image": <PIL.Image.Image>}, {"type": "text", "text": "What's in this image?"}]},
{"role": "assistant", "content": [{"type": "text", "text": "It looks like a cat."}]},
]
```
"""
image_included = False

messages = copy.deepcopy(messages) # avoid modifying the original messages

# First, convert all messages to the structured format if needed, and insert image placeholders if needed
images_included = False
for message in messages:
if message["role"] == "system":
if isinstance(message["content"], str): # if already prepared, the content will be a list
message["content"] = [{"type": "text", "text": message["content"]}]
elif message["role"] == "user":
if isinstance(message["content"], str) and not image_included:
placeholders = [{"type": "image"}] * num_images
message["content"] = [*placeholders, {"type": "text", "text": message["content"]}]
image_included = True
elif isinstance(message["content"], str) and image_included:
if isinstance(message["content"], str) and not images_included:
image_entries = [{"type": "image"}] * len(images)
message["content"] = [*image_entries, {"type": "text", "text": message["content"]}]
images_included = True
elif isinstance(message["content"], str) and images_included:
message["content"] = [{"type": "text", "text": message["content"]}]
elif message["role"] == "assistant":
if isinstance(message["content"], str):
@@ -75,6 +91,55 @@ def prepare_multimodal_messages(messages: list[dict[str, Any]], num_images: int)
else:
raise ValueError(f"Invalid role in message: {message['role']}. Expected 'user', 'assistant', or 'system'.")

# Then, check that the number of image placeholders matches the number of images provided
num_placeholders = sum(sum(1 for part in message["content"] if part["type"] == "image") for message in messages)
if num_placeholders != len(images):
raise ValueError(
f"Number of images provided ({len(images)}) does not match number of image placeholders ({num_placeholders})."
)

# Then, fill in the actual images in the placeholders
img_idx = 0
for message in messages:
for part in message["content"]:
if part["type"] == "image":
part["image"] = images[img_idx]
img_idx += 1

return messages


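The validation added above has a failure mode worth spelling out: when the messages are already structured, the number of `{"type": "image"}` placeholders must match the number of images passed in, otherwise a `ValueError` is raised. A small sketch of both outcomes, assuming the post-change `trl.data_utils` API:

```python
from PIL import Image

from trl.data_utils import prepare_multimodal_messages

# Already-structured message containing exactly one image placeholder
messages = [
    {"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "Describe this."}]},
]

red = Image.new("RGB", (8, 8), color="red")
green = Image.new("RGB", (8, 8), color="green")

prepared = prepare_multimodal_messages(messages, images=[red])  # fills the single placeholder

try:
    prepare_multimodal_messages(messages, images=[red, green])  # 1 placeholder, 2 images
except ValueError as err:
    print(err)  # number of images does not match the number of image placeholders
```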
def prepare_multimodal_messages_vllm(messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
# docstyle-ignore # because <Image> is not parsable in the code block
"""
Convert structured multimodal messages into a format compatible with vLLM. Replaces `"type": "image"` blocks with
`"type": "image_pil"` blocks, and `"image": Image` with `"image_pil": Image`.

Args:
messages (`list[dict[str, Any]]`):
Messages with `"role"` and `"content"`. Content is expected to be a list of structured blocks.

Returns:
`list[dict[str, Any]]`:
A deep-copied list of messages compatible with vLLM's expected input format.

Example:
```python
# Input
[{"role": "user", "content": [{"type": "image", "image": <PIL.Image.Image>}, {"type": "text", "text": "What's in this image?"}]}]

# Output
[{"role": "user", "content": [{"type": "image_pil", "image_pil": <PIL.Image.Image>}, {"type": "text", "text": "What's in this image?"}]}]
```
"""
messages = copy.deepcopy(messages) # avoid modifying the original messages
for message in messages:
for part in message["content"]:
if part["type"] == "image":
part["type"] = "image_pil" # vLLM expects 'image_pil' key for images
part["image_pil"] = part.pop("image")
return messages


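The trainers below chain the two helpers: messages are first structured and filled with images, then re-keyed for vLLM. A short illustrative sketch (not taken verbatim from the diff):

```python
from PIL import Image

from trl.data_utils import prepare_multimodal_messages, prepare_multimodal_messages_vllm

image = Image.new("RGB", (32, 32), color="blue")
messages = [{"role": "user", "content": "What's in this image?"}]

prepared = prepare_multimodal_messages(messages, images=[image])
vllm_ready = prepare_multimodal_messages_vllm(prepared)

# {"type": "image", "image": ...} has become {"type": "image_pil", "image_pil": ...}
assert vllm_ready[0]["content"][0]["type"] == "image_pil"
assert prepared[0]["content"][0]["type"] == "image"  # the intermediate copy is untouched
```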
def is_conversational(example: dict[str, Any]) -> bool:
r"""
@@ -18,7 +18,7 @@ from typing import Any, Callable
import torch
from accelerate.utils import gather_object

from ...data_utils import is_conversational
from ...data_utils import apply_chat_template, is_conversational
from ...trainer.grpo_trainer import GRPOTrainer as _GRPOTrainer
from ...trainer.utils import nanmax, nanmin, nanstd, pad

@@ -80,13 +80,9 @@ class GFPOTrainer(_GRPOTrainer):
if images is not None and all(img_list == [] for img_list in images):
images = None

(
prompt_ids_list,
completion_ids_list,
num_items_in_batch,
sampling_per_token_logps_list,
forward_kwargs,
) = self._generate(prompts, images)
prompt_ids_list, completion_ids_list, num_items_in_batch, sampling_per_token_logps_list = self._generate(
prompts, images
)

# Convert lists of token IDs to padded tensors
prompt_ids = [torch.tensor(ids, device=device) for ids in prompt_ids_list]
@@ -112,6 +108,23 @@ class GFPOTrainer(_GRPOTrainer):
# Concatenate prompt_mask with completion_mask for logit computation
prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1) # (B, P+C)
attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) # (B, P+C)

logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens
batch_size = self.args.per_device_train_batch_size if mode == "train" else self.args.per_device_eval_batch_size

num_images = [len(img_list) for img_list in images] if images is not None else None

# Get forward_kwargs for models with multimodal inputs
if images is not None:
prompts_text = [
apply_chat_template({"prompt": prompt}, self.processing_class)["prompt"] for prompt in prompts
]
prompt_inputs = self.processing_class(images=images, text=prompts_text, padding=True, return_tensors="pt")
prompt_inputs = super()._prepare_inputs(prompt_inputs)
forward_kwargs = {k: v for k, v in prompt_inputs.items() if k not in ["input_ids", "attention_mask"]}
else:
forward_kwargs = {}

# If token_type_ids are used, extend them with zeros for the completion part
if "token_type_ids" in forward_kwargs:
token_type_ids = forward_kwargs["token_type_ids"]
@@ -119,11 +132,6 @@ class GFPOTrainer(_GRPOTrainer):
[token_type_ids, token_type_ids.new_zeros(completion_ids.shape)], dim=1
)

logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens
batch_size = self.args.per_device_train_batch_size if mode == "train" else self.args.per_device_eval_batch_size

num_images = [len(img_list) for img_list in images] if images is not None else None

with torch.no_grad():
# If the generation and optimization steps are misaligned—i.e., if generation does not occur at the end of
# a full optimizer step (when gradient_accumulation_steps is not a multiple of generate_every)—then the
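In the GFPO trainer above, the multimodal `forward_kwargs` are now rebuilt inside `_generate_and_score_completions` rather than returned by `_generate`: the chat template is re-applied as text, the processor is run on text plus images, and only the extra multimodal tensors are kept. A rough illustration of that filtering step; the dict below is a hypothetical processor output (the real code uses `self.processing_class` and `super()._prepare_inputs`):

```python
import torch

# Stand-in for what a multimodal processor might return for text + images; key names
# such as pixel_values / image_grid_thw vary by model and are assumptions here.
prompt_inputs = {
    "input_ids": torch.tensor([[1, 2, 3]]),
    "attention_mask": torch.tensor([[1, 1, 1]]),
    "pixel_values": torch.rand(1, 3, 224, 224),
    "image_grid_thw": torch.tensor([[1, 16, 16]]),
}

# Keep only the extra multimodal tensors; ids and attention mask are rebuilt from the
# generation output, so they must not be passed to the model a second time.
forward_kwargs = {k: v for k, v in prompt_inputs.items() if k not in ["input_ids", "attention_mask"]}

print(sorted(forward_kwargs))  # ['image_grid_thw', 'pixel_values']
```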
@@ -44,9 +44,14 @@ from transformers import (
is_wandb_available,
)
from transformers.trainer_utils import seed_worker
from transformers.utils import is_datasets_available, is_flash_attn_2_available, is_peft_available, is_rich_available
from transformers.utils import is_datasets_available, is_peft_available, is_rich_available

from ..data_utils import apply_chat_template, is_conversational, maybe_apply_chat_template, prepare_multimodal_messages
from ..data_utils import (
apply_chat_template,
is_conversational,
prepare_multimodal_messages,
prepare_multimodal_messages_vllm,
)
from ..extras.profiling import profiling_context, profiling_decorator
from ..extras.vllm_client import VLLMClient
from ..import_utils import is_liger_kernel_available, is_vllm_available
@@ -1069,30 +1074,9 @@ class GRPOTrainer(BaseTrainer):
rewards_per_func = gather(rewards_per_func)
return rewards_per_func

def _generate_single_turn(self, prompts: list[str], images: Optional[list]):
def _generate_single_turn(self, prompts: list):
device = self.accelerator.device

# If the prompts are conversational and the inputs contain images, we need to convert the prompts from
# [{"role": "user", "content": "What color is the sky?"}] to
# [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "What color is the sky?"}]}]
kwargs = {}
if images is not None:
kwargs = {"images": images}
for prompt, image_list in zip(prompts, images):
if isinstance(prompt, list): # i.e., when using conversational data
prepare_multimodal_messages(prompt, num_images=len(image_list))

prompts_text = [
maybe_apply_chat_template({"prompt": prompt}, self.processing_class)["prompt"] for prompt in prompts
]

if images is not None:
prompt_inputs = self.processing_class(text=prompts_text, padding=True, return_tensors="pt", **kwargs)
prompt_inputs = super()._prepare_inputs(prompt_inputs)
forward_kwargs = {k: v for k, v in prompt_inputs.items() if k not in ["input_ids", "attention_mask"]}
else:
forward_kwargs = {}

# Generate completions using either vLLM or regular generation
if self.use_vllm:
if self.vllm_mode == "colocate" and self.args.vllm_enable_sleep_mode:
@@ -1105,38 +1089,35 @@ class GRPOTrainer(BaseTrainer):
self._move_model_to_vllm()
self._last_loaded_step = self.state.global_step

prompts = [prepare_multimodal_messages_vllm(prompt) for prompt in prompts]

# Generate completions using vLLM: gather all prompts and use them in a single call in the main process
if self.vllm_mode == "server":
all_prompts_text = gather_object(prompts_text)
if images is not None:
all_images = gather_object(images)
all_prompts = gather_object(prompts)

if self.accelerator.is_main_process:
# Since 'prompts' contains 'num_generations' duplicates, we first take unique prompts, and generate
# num_generations outputs for each one. This is faster than generating outputs for each duplicate
# prompt individually.
ordered_set_of_prompts = all_prompts_text[:: self.num_generations]

if images is not None:
ordered_set_of_images = all_images[:: self.num_generations]
else:
ordered_set_of_images = None
ordered_set_of_prompts = all_prompts[:: self.num_generations]

sampling_params = {
"n": self.num_generations,
"repetition_penalty": self.repetition_penalty,
"temperature": self.temperature,
"top_p": self.top_p,
"top_k": -1 if self.top_k is None else self.top_k,
"min_p": 0.0 if self.min_p is None else self.min_p,
"max_tokens": self.max_completion_length,
"truncate_prompt_tokens": self.max_prompt_length,
"guided_decoding_regex": self.guided_decoding_regex,
"generation_kwargs": self.args.generation_kwargs,
}
with profiling_context(self, "vLLM.generate"):
output = self.vllm_client.generate(
prompts=ordered_set_of_prompts,
images=ordered_set_of_images,
n=self.num_generations,
repetition_penalty=self.repetition_penalty,
temperature=self.temperature,
top_p=self.top_p,
top_k=-1 if self.top_k is None else self.top_k,
min_p=0.0 if self.min_p is None else self.min_p,
max_tokens=self.max_completion_length,
truncate_prompt_tokens=self.max_prompt_length,
guided_decoding_regex=self.guided_decoding_regex,
generation_kwargs=self.args.generation_kwargs,
)
if is_conversational({"prompt": ordered_set_of_prompts[0]}):
output = self.vllm_client.chat(prompts=ordered_set_of_prompts, **sampling_params)
else:
output = self.vllm_client.generate(prompts=ordered_set_of_prompts, **sampling_params)
payload = (output["prompt_ids"], output["completion_ids"], output["logprobs"])
else:
payload = None
@@ -1183,31 +1164,18 @@ class GRPOTrainer(BaseTrainer):
if self.vllm_tensor_parallel_size > 1:
# Gather prompts from all ranks in the TP group and flatten.
# Each rank starts with its own prompts; after gathering, all ranks see the full group set.
orig_size = len(prompts_text)
orig_size = len(prompts)
gathered_prompts = [None for _ in range(self.vllm_tensor_parallel_size)]
torch.distributed.all_gather_object(gathered_prompts, prompts_text, group=self.tp_group)
all_prompts_text = [p for sublist in gathered_prompts for p in sublist]

if images is not None:
gathered_images = [None for _ in range(self.vllm_tensor_parallel_size)]
torch.distributed.all_gather_object(gathered_images, images, group=self.tp_group)
all_images = [img for sublist in gathered_images for img in sublist]
else:
all_images = None
torch.distributed.all_gather_object(gathered_prompts, prompts, group=self.tp_group)
all_prompts = [p for sublist in gathered_prompts for p in sublist]
else:
all_prompts_text = prompts_text
all_images = images

if images is not None and all_images:
vllm_inputs = []
for prompt, image_list in zip(all_prompts_text, all_images):
vllm_inputs.append({"prompt": prompt, "multi_modal_data": {"image": image_list}})

else:
vllm_inputs = all_prompts_text
all_prompts = prompts

with profiling_context(self, "vLLM.generate"):
all_outputs = self.llm.generate(vllm_inputs, sampling_params=sampling_params, use_tqdm=False)
if is_conversational({"prompt": prompts[0]}):
all_outputs = self.llm.chat(all_prompts, sampling_params=sampling_params, use_tqdm=False)
else:
all_outputs = self.llm.generate(all_prompts, sampling_params=sampling_params, use_tqdm=False)

all_prompt_ids = [output.prompt_token_ids for output in all_outputs]
all_completion_ids = [output.token_ids for outputs in all_outputs for output in outputs.outputs]
@@ -1234,15 +1202,20 @@ class GRPOTrainer(BaseTrainer):
self.llm.sleep(level=1)

elif self.use_transformers_paged:
# Re-process inputs for paged generation if needed
# Note: images are already validated and preprocessed above
paged_prompt_inputs = self.processing_class(text=prompts_text, **kwargs)
previous_attn = self.model_wrapped.config._attn_implementation

if is_flash_attn_2_available():
self.model_wrapped.config._attn_implementation = "paged_attention"
processor_kwargs = {
"max_length": self.max_prompt_length,
"truncation": True,
"return_dict": True,
"add_special_tokens": False,
}
if is_conversational({"prompt": prompts[0]}):
generate_inputs = self.processing_class.apply_chat_template(
conversation=prompts, **processor_kwargs, tokenize=True
)
else:
self.model_wrapped.config._attn_implementation = "sdpa_paged"
generate_inputs = self.processing_class(text=prompts, **processor_kwargs)
generate_inputs["inputs"] = generate_inputs.pop("input_ids")

with (
profiling_context(self, "transformers.generate_batch"),
unwrap_model_for_generation(
@@ -1258,27 +1231,30 @@ class GRPOTrainer(BaseTrainer):
unwrapped_model.to(torch.float16)
with torch.inference_mode():
all_outputs = unwrapped_model.generate_batch(
paged_prompt_inputs.input_ids, generation_config=self.generation_config, progress_bar=False
**generate_inputs, generation_config=self.generation_config, progress_bar=False
)
unwrapped_model.train() # restore training mode, as generate_batch forces eval mode
completion_ids = [output.generated_tokens for output in all_outputs.values()]
prompt_ids = paged_prompt_inputs.input_ids
# Restore the original attention implementation, training mode
self.model_wrapped.config._attn_implementation = previous_attn
prompt_ids = generate_inputs["inputs"]
logprobs = None # not used in this case

else:
# Regular generation path
generate_inputs = self.processing_class(
text=prompts_text,
return_tensors="pt",
padding=True,
padding_side="left",
max_length=self.max_prompt_length,
truncation=True,
add_special_tokens=False,
**kwargs,
)
processor_kwargs = {
"return_tensors": "pt",
"padding": True,
"padding_side": "left",
"max_length": self.max_prompt_length,
"truncation": True,
"return_dict": True,
"add_special_tokens": False,
}
if is_conversational({"prompt": prompts[0]}):
generate_inputs = self.processing_class.apply_chat_template(
conversation=prompts, **processor_kwargs, tokenize=True
)
else:
generate_inputs = self.processing_class(text=prompts, **processor_kwargs)
generate_inputs = super()._prepare_inputs(generate_inputs)

with (
@@ -1307,13 +1283,13 @@ class GRPOTrainer(BaseTrainer):
completion_ids = [c[m].tolist() for c, m in zip(completion_ids, completion_mask.bool())]
logprobs = None # not used in this case

return prompt_ids, completion_ids, logprobs, forward_kwargs
return prompt_ids, completion_ids, logprobs

def _generate(self, prompts: list[str], images: Optional[list]):
def _generate(self, prompts: list[str]):
device = self.accelerator.device
mode = "train" if self.model.training else "eval"

prompt_ids, completion_ids, logprobs, forward_kwargs = self._generate_single_turn(prompts, images)
prompt_ids, completion_ids, logprobs = self._generate_single_turn(prompts)

# Get completion length per sequence, used for logging
prompt_lengths = torch.tensor([len(ids) for ids in prompt_ids], device=device)
@@ -1345,7 +1321,7 @@ class GRPOTrainer(BaseTrainer):
self._metrics[mode]["completions/min_terminated_length"].append(term_completion_lengths.float().min().item())
self._metrics[mode]["completions/max_terminated_length"].append(term_completion_lengths.float().max().item())

return prompt_ids, completion_ids, total_completion_tokens, logprobs, forward_kwargs
return prompt_ids, completion_ids, total_completion_tokens, logprobs

def _generate_and_score_completions(
self, inputs: list[dict[str, Union[torch.Tensor, Any]]]
@@ -1365,13 +1341,15 @@ class GRPOTrainer(BaseTrainer):
if images is not None and all(img_list == [] for img_list in images):
images = None

(
prompt_ids_list,
completion_ids_list,
num_items_in_batch,
sampling_per_token_logps_list,
forward_kwargs,
) = self._generate(prompts, images)
# If the prompts are conversational and the inputs contain images, we need to convert the prompts from
# [{"role": "user", "content": "What color is the sky?"}] to
# [{"role": "user", "content": [{"type": "image", "image": <Image>}, {"type": "text", "text": "What color is the sky?"}]}]
if images is not None:
prompts = [prepare_multimodal_messages(prompt, image_list) for prompt, image_list in zip(prompts, images)]

prompt_ids_list, completion_ids_list, num_items_in_batch, sampling_per_token_logps_list = self._generate(
prompts
)

# Convert lists of token IDs to padded tensors
prompt_ids = [torch.tensor(ids, device=device) for ids in prompt_ids_list]
@@ -1397,6 +1375,23 @@ class GRPOTrainer(BaseTrainer):
# Concatenate prompt_mask with completion_mask for logit computation
prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1) # (B, P+C)
attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) # (B, P+C)

logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens
batch_size = self.args.per_device_train_batch_size if mode == "train" else self.args.per_device_eval_batch_size

num_images = [len(img_list) for img_list in images] if images is not None else None

# Get forward_kwargs for models with multimodal inputs
if images is not None:
prompts_text = [
apply_chat_template({"prompt": prompt}, self.processing_class)["prompt"] for prompt in prompts
]
prompt_inputs = self.processing_class(images=images, text=prompts_text, padding=True, return_tensors="pt")
prompt_inputs = super()._prepare_inputs(prompt_inputs)
forward_kwargs = {k: v for k, v in prompt_inputs.items() if k not in ["input_ids", "attention_mask"]}
else:
forward_kwargs = {}

# If token_type_ids are used, extend them with zeros for the completion part
if "token_type_ids" in forward_kwargs:
token_type_ids = forward_kwargs["token_type_ids"]
@@ -1404,11 +1399,6 @@ class GRPOTrainer(BaseTrainer):
[token_type_ids, token_type_ids.new_zeros(completion_ids.shape)], dim=1
)

logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens
batch_size = self.args.per_device_train_batch_size if mode == "train" else self.args.per_device_eval_batch_size

num_images = [len(img_list) for img_list in images] if images is not None else None

with torch.no_grad():
# If the generation and optimization steps are misaligned—i.e., if generation does not occur at the end of
# a full optimizer step (when gradient_accumulation_steps is not a multiple of generate_every)—then the
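The regular (non-vLLM, non-paged) generation path above now branches on whether the prompts are conversational: chat-formatted prompts go through `apply_chat_template(..., tokenize=True, return_dict=True)`, plain strings through the tokenizer or processor directly. A minimal sketch of the two branches; the checkpoint name is only an example:

```python
from transformers import AutoTokenizer

from trl.data_utils import is_conversational

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")  # example chat model

conversational_prompt = [{"role": "user", "content": "What color is the sky?"}]
plain_prompt = "The sky is"

for prompt in (conversational_prompt, plain_prompt):
    if is_conversational({"prompt": prompt}):
        # Chat-formatted prompt: render the chat template and tokenize in one call
        inputs = tokenizer.apply_chat_template(
            conversation=[prompt], tokenize=True, return_dict=True, return_tensors="pt"
        )
    else:
        # Plain-text prompt: tokenize directly
        inputs = tokenizer(text=[prompt], return_tensors="pt", padding=True)
    print(inputs["input_ids"].shape)
```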
@@ -45,9 +45,14 @@ from transformers import (
is_wandb_available,
)
from transformers.trainer_utils import seed_worker
from transformers.utils import is_datasets_available, is_flash_attn_2_available, is_peft_available, is_rich_available
from transformers.utils import is_datasets_available, is_peft_available, is_rich_available

from ..data_utils import apply_chat_template, is_conversational, maybe_apply_chat_template, prepare_multimodal_messages
from ..data_utils import (
apply_chat_template,
is_conversational,
prepare_multimodal_messages,
prepare_multimodal_messages_vllm,
)
from ..extras.profiling import profiling_context, profiling_decorator
from ..extras.vllm_client import VLLMClient
from ..import_utils import is_vllm_available
@@ -1065,30 +1070,9 @@ class RLOOTrainer(BaseTrainer):
rewards_per_func = gather(rewards_per_func)
return rewards_per_func

def _generate_single_turn(self, prompts: list[str], images: Optional[list]):
def _generate_single_turn(self, prompts: list):
device = self.accelerator.device

# If the prompts are conversational and the inputs contain images, we need to convert the prompts from
# [{"role": "user", "content": "What color is the sky?"}] to
# [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "What color is the sky?"}]}]
kwargs = {}
if images is not None:
kwargs = {"images": images}
for prompt, image_list in zip(prompts, images):
if isinstance(prompt, list): # i.e., when using conversational data
prepare_multimodal_messages(prompt, num_images=len(image_list))

prompts_text = [
maybe_apply_chat_template({"prompt": prompt}, self.processing_class)["prompt"] for prompt in prompts
]

if images is not None:
prompt_inputs = self.processing_class(text=prompts_text, padding=True, return_tensors="pt", **kwargs)
prompt_inputs = super()._prepare_inputs(prompt_inputs)
forward_kwargs = {k: v for k, v in prompt_inputs.items() if k not in ["input_ids", "attention_mask"]}
else:
forward_kwargs = {}

# Generate completions using either vLLM or regular generation
if self.use_vllm:
if self.vllm_mode == "colocate" and self.args.vllm_enable_sleep_mode:
@@ -1101,38 +1085,35 @@ class RLOOTrainer(BaseTrainer):
self._move_model_to_vllm()
self._last_loaded_step = self.state.global_step

prompts = [prepare_multimodal_messages_vllm(prompt) for prompt in prompts]

# Generate completions using vLLM: gather all prompts and use them in a single call in the main process
if self.vllm_mode == "server":
all_prompts_text = gather_object(prompts_text)
if images is not None:
all_images = gather_object(images)
all_prompts = gather_object(prompts)

if self.accelerator.is_main_process:
# Since 'prompts' contains 'num_generations' duplicates, we first take unique prompts, and generate
# num_generations outputs for each one. This is faster than generating outputs for each duplicate
# prompt individually.
ordered_set_of_prompts = all_prompts_text[:: self.num_generations]

if images is not None:
ordered_set_of_images = all_images[:: self.num_generations]
else:
ordered_set_of_images = None
ordered_set_of_prompts = all_prompts[:: self.num_generations]

sampling_params = {
"n": self.num_generations,
"repetition_penalty": self.repetition_penalty,
"temperature": self.temperature,
"top_p": self.top_p,
"top_k": -1 if self.top_k is None else self.top_k,
"min_p": 0.0 if self.min_p is None else self.min_p,
"max_tokens": self.max_completion_length,
"truncate_prompt_tokens": self.max_prompt_length,
"guided_decoding_regex": self.guided_decoding_regex,
"generation_kwargs": self.args.generation_kwargs,
}
with profiling_context(self, "vLLM.generate"):
output = self.vllm_client.generate(
prompts=ordered_set_of_prompts,
images=ordered_set_of_images,
n=self.num_generations,
repetition_penalty=self.repetition_penalty,
temperature=self.temperature,
top_p=self.top_p,
top_k=-1 if self.top_k is None else self.top_k,
min_p=0.0 if self.min_p is None else self.min_p,
max_tokens=self.max_completion_length,
truncate_prompt_tokens=self.max_prompt_length,
guided_decoding_regex=self.guided_decoding_regex,
generation_kwargs=self.args.generation_kwargs,
)
if is_conversational({"prompt": ordered_set_of_prompts[0]}):
output = self.vllm_client.chat(prompts=ordered_set_of_prompts, **sampling_params)
else:
output = self.vllm_client.generate(prompts=ordered_set_of_prompts, **sampling_params)
payload = (output["prompt_ids"], output["completion_ids"], output["logprobs"])
else:
payload = None
@@ -1177,31 +1158,18 @@ class RLOOTrainer(BaseTrainer):
if self.vllm_tensor_parallel_size > 1:
# Gather prompts from all ranks in the TP group and flatten.
# Each rank starts with its own prompts; after gathering, all ranks see the full group set.
orig_size = len(prompts_text)
orig_size = len(prompts)
gathered_prompts = [None for _ in range(self.vllm_tensor_parallel_size)]
torch.distributed.all_gather_object(gathered_prompts, prompts_text, group=self.tp_group)
all_prompts_text = [p for sublist in gathered_prompts for p in sublist]

if images is not None:
gathered_images = [None for _ in range(self.vllm_tensor_parallel_size)]
torch.distributed.all_gather_object(gathered_images, images, group=self.tp_group)
all_images = [img for sublist in gathered_images for img in sublist]
else:
all_images = None
torch.distributed.all_gather_object(gathered_prompts, prompts, group=self.tp_group)
all_prompts = [p for sublist in gathered_prompts for p in sublist]
else:
all_prompts_text = prompts_text
all_images = images

if images is not None and all_images:
vllm_inputs = []
for prompt, image_list in zip(all_prompts_text, all_images):
vllm_inputs.append({"prompt": prompt, "multi_modal_data": {"image": image_list}})

else:
vllm_inputs = all_prompts_text
all_prompts = prompts

with profiling_context(self, "vLLM.generate"):
all_outputs = self.llm.generate(vllm_inputs, sampling_params=sampling_params, use_tqdm=False)
if is_conversational({"prompt": prompts[0]}):
all_outputs = self.llm.chat(all_prompts, sampling_params=sampling_params, use_tqdm=False)
else:
all_outputs = self.llm.generate(all_prompts, sampling_params=sampling_params, use_tqdm=False)

all_prompt_ids = [output.prompt_token_ids for output in all_outputs]
all_completion_ids = [output.token_ids for outputs in all_outputs for output in outputs.outputs]
@@ -1221,15 +1189,20 @@ class RLOOTrainer(BaseTrainer):
self.llm.sleep(level=1)

elif self.use_transformers_paged:
# Re-process inputs for paged generation if needed
# Note: images are already validated and preprocessed above
paged_prompt_inputs = self.processing_class(text=prompts_text, **kwargs)
previous_attn = self.model_wrapped.config._attn_implementation

if is_flash_attn_2_available():
self.model_wrapped.config._attn_implementation = "paged_attention"
processor_kwargs = {
"max_length": self.max_prompt_length,
"truncation": True,
"return_dict": True,
"add_special_tokens": False,
}
if is_conversational({"prompt": prompts[0]}):
generate_inputs = self.processing_class.apply_chat_template(
conversation=prompts, **processor_kwargs, tokenize=True
)
else:
self.model_wrapped.config._attn_implementation = "sdpa_paged"
generate_inputs = self.processing_class(text=prompts, **processor_kwargs)
generate_inputs["inputs"] = generate_inputs.pop("input_ids")

with (
profiling_context(self, "transformers.generate_batch"),
unwrap_model_for_generation(
@@ -1245,26 +1218,29 @@ class RLOOTrainer(BaseTrainer):
unwrapped_model.to(torch.float16)
with torch.inference_mode():
all_outputs = unwrapped_model.generate_batch(
paged_prompt_inputs.input_ids, generation_config=self.generation_config, progress_bar=False
**generate_inputs, generation_config=self.generation_config, progress_bar=False
)
unwrapped_model.train() # restore training mode, as generate_batch forces eval mode
completion_ids = [output.generated_tokens for output in all_outputs.values()]
prompt_ids = paged_prompt_inputs.input_ids
# Restore the original attention implementation, training mode
self.model_wrapped.config._attn_implementation = previous_attn
prompt_ids = generate_inputs["inputs"]

else:
# Regular generation path
generate_inputs = self.processing_class(
text=prompts_text,
return_tensors="pt",
padding=True,
padding_side="left",
max_length=self.max_prompt_length,
truncation=True,
add_special_tokens=False,
**kwargs,
)
processor_kwargs = {
"return_tensors": "pt",
"padding": True,
"padding_side": "left",
"max_length": self.max_prompt_length,
"truncation": True,
"return_dict": True,
"add_special_tokens": False,
}
if is_conversational({"prompt": prompts[0]}):
generate_inputs = self.processing_class.apply_chat_template(
conversation=prompts, **processor_kwargs, tokenize=True
)
else:
generate_inputs = self.processing_class(text=prompts, **processor_kwargs)
generate_inputs = super()._prepare_inputs(generate_inputs)

with (
@@ -1292,13 +1268,13 @@ class RLOOTrainer(BaseTrainer):
prompt_ids = [p[m].tolist() for p, m in zip(prompt_ids, prompt_mask.bool())]
completion_ids = [c[m].tolist() for c, m in zip(completion_ids, completion_mask.bool())]

return prompt_ids, completion_ids, forward_kwargs
return prompt_ids, completion_ids

def _generate(self, prompts: list[str], images: Optional[list]):
def _generate(self, prompts: list[str]):
device = self.accelerator.device
mode = "train" if self.model.training else "eval"

prompt_ids, completion_ids, forward_kwargs = self._generate_single_turn(prompts, images)
prompt_ids, completion_ids = self._generate_single_turn(prompts)

# Get completion length per sequence, used for logging
prompt_lengths = torch.tensor([len(ids) for ids in prompt_ids], device=device)
@@ -1331,7 +1307,7 @@ class RLOOTrainer(BaseTrainer):
self._metrics[mode]["completions/min_terminated_length"].append(term_completion_lengths.float().min().item())
self._metrics[mode]["completions/max_terminated_length"].append(term_completion_lengths.float().max().item())

return prompt_ids, completion_ids, forward_kwargs
return prompt_ids, completion_ids

def _generate_and_score_completions(
self, inputs: list[dict[str, Union[torch.Tensor, Any]]]
@@ -1351,7 +1327,13 @@ class RLOOTrainer(BaseTrainer):
if images is not None and all(img_list == [] for img_list in images):
images = None

prompt_ids_list, completion_ids_list, forward_kwargs = self._generate(prompts, images)
# If the prompts are conversational and the inputs contain images, we need to convert the prompts from
# [{"role": "user", "content": "What color is the sky?"}] to
# [{"role": "user", "content": [{"type": "image", "image": <Image>}, {"type": "text", "text": "What color is the sky?"}]}]
if images is not None:
prompts = [prepare_multimodal_messages(prompt, image_list) for prompt, image_list in zip(prompts, images)]

prompt_ids_list, completion_ids_list = self._generate(prompts)

# Convert lists of token IDs to padded tensors
prompt_ids = [torch.tensor(ids, device=device) for ids in prompt_ids_list]
@@ -1372,6 +1354,23 @@ class RLOOTrainer(BaseTrainer):
# Concatenate prompt_mask with completion_mask for logit computation
prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1) # (B, P+C)
attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) # (B, P+C)

logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens
batch_size = self.args.per_device_train_batch_size if mode == "train" else self.args.per_device_eval_batch_size

num_images = [len(img_list) for img_list in images] if images is not None else None

# Get forward_kwargs for models with multimodal inputs
if images is not None:
prompts_text = [
apply_chat_template({"prompt": prompt}, self.processing_class)["prompt"] for prompt in prompts
]
prompt_inputs = self.processing_class(images=images, text=prompts_text, padding=True, return_tensors="pt")
prompt_inputs = super()._prepare_inputs(prompt_inputs)
forward_kwargs = {k: v for k, v in prompt_inputs.items() if k not in ["input_ids", "attention_mask"]}
else:
forward_kwargs = {}

# If token_type_ids are used, extend them with zeros for the completion part
if "token_type_ids" in forward_kwargs:
token_type_ids = forward_kwargs["token_type_ids"]
@@ -1379,11 +1378,6 @@ class RLOOTrainer(BaseTrainer):
[token_type_ids, token_type_ids.new_zeros(completion_ids.shape)], dim=1
)

logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens
batch_size = self.args.per_device_train_batch_size if mode == "train" else self.args.per_device_eval_batch_size

num_images = [len(img_list) for img_list in images] if images is not None else None

with torch.no_grad():
# Compute the per-token log probabilities for the current model
old_per_token_logps, _ = self._get_per_token_logps_and_entropies(
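One small detail in the paged-generation branch above: `generate_batch` is fed the processed inputs with the token ids re-keyed from `input_ids` to `inputs`. A tiny sketch of that re-keying; the ids below are dummies standing in for the output of `apply_chat_template(..., tokenize=True, return_dict=True)` or of the processor:

```python
# Stand-in for the dict returned by the processor or by apply_chat_template
generate_inputs = {
    "input_ids": [[101, 42, 7, 9]],
    "attention_mask": [[1, 1, 1, 1]],
}

# generate_batch expects `inputs`, so the token ids are moved under that key
generate_inputs["inputs"] = generate_inputs.pop("input_ids")

print(sorted(generate_inputs))  # ['attention_mask', 'inputs']
```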
@@ -353,9 +353,7 @@ class DataCollatorForVisionLanguageModeling(DataCollatorMixin):
images = None

if "messages" in examples[0]: # conversational case
for example in examples:
prepare_multimodal_messages(example["messages"], len(example["images"]))
messages = [example["messages"] for example in examples]
messages = [prepare_multimodal_messages(example["messages"], example["images"]) for example in examples]
texts = self.processor.apply_chat_template(messages)
elif self.dataset_text_field in examples[0]: # standard case
texts = [example[self.dataset_text_field] for example in examples]
@@ -396,7 +394,8 @@ class DataCollatorForVisionLanguageModeling(DataCollatorMixin):
images = None
if is_conversational(examples[0]): # conversational case
for example in examples:
prepare_multimodal_messages(example["prompt"] + example["completion"], len(example["images"]))
example["prompt"] = prepare_multimodal_messages(example["prompt"], images=example["images"])
example["completion"] = prepare_multimodal_messages(example["completion"], images=[])
examples = [apply_chat_template(example, self.processor) for example in examples]

prompts = [example["prompt"] for example in examples]
@@ -951,10 +950,13 @@ class SFTTrainer(BaseTrainer):
output = {}
if is_conversational(example):
if self._is_vlm:
prepare_multimodal_messages(example["prompt"], num_images=0)
prepare_multimodal_messages(example["completion"], num_images=0)
prompt = prepare_multimodal_messages(example["prompt"], images=[])
completion = prepare_multimodal_messages(example["completion"], images=[])
else:
prompt = example["prompt"]
completion = example["completion"]
prompt_ids = processing_class.apply_chat_template(
example["prompt"],
prompt,
tokenize=True,
add_generation_prompt=True,
tools=example.get("tools"),
@@ -964,7 +966,7 @@ class SFTTrainer(BaseTrainer):
# even for single examples, while for LLMs it returns lists of ints.
prompt_ids = prompt_ids[0] if isinstance(prompt_ids[0], list) else prompt_ids
prompt_completion_processed = processing_class.apply_chat_template(
example["prompt"] + example["completion"],
prompt + completion,
return_dict=True,
tokenize=True,
return_assistant_tokens_mask=assistant_only_loss,
@@ -1002,9 +1004,11 @@ class SFTTrainer(BaseTrainer):
else: # language modeling case
if is_conversational(example):
if self._is_vlm:
prepare_multimodal_messages(example["messages"], num_images=0)
messages = prepare_multimodal_messages(example["messages"], images=[])
else:
messages = example["messages"]
processed = processing_class.apply_chat_template(
example["messages"],
messages,
return_dict=True,
tokenize=True,
return_assistant_tokens_mask=assistant_only_loss,
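In the SFT changes above, text-only conversations for VLM checkpoints are normalized with `prepare_multimodal_messages(..., images=[])` before templating, so every `content` field becomes a list of typed blocks even when no image is present. A small sketch of that normalization, assuming the post-change trl API:

```python
from trl.data_utils import prepare_multimodal_messages

prompt = [{"role": "user", "content": "What color is the sky?"}]
completion = [{"role": "assistant", "content": "It is blue."}]

# No images: the helper still wraps every string content into structured text blocks
prompt = prepare_multimodal_messages(prompt, images=[])
completion = prepare_multimodal_messages(completion, images=[])

print(prompt[0]["content"])      # [{'type': 'text', 'text': 'What color is the sky?'}]
print(completion[0]["content"])  # [{'type': 'text', 'text': 'It is blue.'}]
```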