mirror of https://github.com/huggingface/trl.git
synced 2025-10-20 18:43:52 +08:00

♻️ Reuse multimodal message preparation from SFTTrainer in GRPOTrainer (#3919)

Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
docs/source/data_utils.md

@@ -1,9 +1,17 @@
 # Data Utilities
 
+## prepare_multimodal_messages
+
+[[autodoc]] prepare_multimodal_messages
+
 ## is_conversational
 
 [[autodoc]] is_conversational
 
+## is_conversational_from_value
+
+[[autodoc]] is_conversational_from_value
+
 ## apply_chat_template
 
 [[autodoc]] apply_chat_template
@@ -13,7 +21,7 @@
 [[autodoc]] maybe_apply_chat_template
 
 ## maybe_convert_to_chatml
 
 [[autodoc]] maybe_convert_to_chatml
 
 ## extract_prompt
tests/test_data_utils.py

@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import copy
 import itertools
 import textwrap
 import unittest
@@ -31,6 +32,7 @@ from trl.data_utils import (
     maybe_extract_prompt,
     maybe_unpair_preference_dataset,
     pack_dataset,
+    prepare_multimodal_messages,
     truncate_dataset,
     unpair_preference_dataset,
 )
@@ -38,6 +40,113 @@ from trl.data_utils import (
 from .testing_utils import TrlTestCase
 
 
+class PrepareMultimodalMessagesTester(unittest.TestCase):
+    def test_basic_user_assistant_conversation(self):
+        """Test basic conversation with user and assistant messages."""
+        messages = [
+            {"role": "user", "content": "What color is the sky?"},
+            {"role": "assistant", "content": "It is blue."},
+        ]
+
+        prepare_multimodal_messages(messages, num_images=1)
+
+        expected = [
+            {"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "What color is the sky?"}]},
+            {"role": "assistant", "content": [{"type": "text", "text": "It is blue."}]},
+        ]
+
+        self.assertEqual(messages, expected)
+
+    def test_first_user_message_gets_image(self):
+        """Test that only the first user message gets an image placeholder."""
+        messages = [
+            {"role": "user", "content": "What color is the sky?"},
+            {"role": "assistant", "content": "It is blue."},
+            {"role": "user", "content": "How about the grass?"},
+        ]
+
+        prepare_multimodal_messages(messages, num_images=1)
+
+        expected = [
+            {"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "What color is the sky?"}]},
+            {"role": "assistant", "content": [{"type": "text", "text": "It is blue."}]},
+            {"role": "user", "content": [{"type": "text", "text": "How about the grass?"}]},
+        ]
+
+        self.assertEqual(messages, expected)
+
+    def test_multiple_images(self):
+        """Test that multiple images are added to the first user message."""
+        messages = [
+            {"role": "user", "content": "What color is the sky?"},
+            {"role": "assistant", "content": "It is blue."},
+        ]
+
+        prepare_multimodal_messages(messages, num_images=3)
+
+        expected = [
+            {
+                "role": "user",
+                "content": [
+                    {"type": "image"},
+                    {"type": "image"},
+                    {"type": "image"},
+                    {"type": "text", "text": "What color is the sky?"},
+                ],
+            },
+            {"role": "assistant", "content": [{"type": "text", "text": "It is blue."}]},
+        ]
+
+        self.assertEqual(messages, expected)
+
+    def test_system_message_transformation(self):
+        """Test that system messages are properly transformed."""
+        messages = [
+            {"role": "system", "content": "You are a helpful assistant"},
+            {"role": "user", "content": "What color is the sky?"},
+        ]
+
+        prepare_multimodal_messages(messages, num_images=1)
+
+        expected = [
+            {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant"}]},
+            {"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "What color is the sky?"}]},
+        ]
+
+        self.assertEqual(messages, expected)
+
+    def test_already_prepared_messages_unchanged(self):
+        """Test that messages with list content are not modified."""
+        messages = [
+            {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant"}]},
+            {"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "What color is the sky?"}]},
+            {"role": "assistant", "content": [{"type": "text", "text": "It is blue."}]},
+        ]
+
+        original = copy.deepcopy(messages)
+        prepare_multimodal_messages(messages, num_images=1)
+
+        self.assertEqual(messages, original)
+
+    def test_mixed_prepared_and_unprepared_messages(self):
+        """Test handling of mixed prepared and unprepared messages."""
+        messages = [
+            {"role": "user", "content": "What color is the sky?"},
+            {"role": "assistant", "content": [{"type": "text", "text": "It is blue."}]},
+            {"role": "user", "content": "What about the grass?"},
+        ]
+
+        prepare_multimodal_messages(messages, num_images=1)
+
+        expected = [
+            {"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "What color is the sky?"}]},
+            {"role": "assistant", "content": [{"type": "text", "text": "It is blue."}]},
+            {"role": "user", "content": [{"type": "text", "text": "What about the grass?"}]},
+        ]
+
+        self.assertEqual(messages, expected)
+
+
 class IsConversationalTester(TrlTestCase):
     conversational_examples = [
         {  # Language modeling
trl/__init__.py

@@ -25,11 +25,13 @@ _import_structure = {
         "apply_chat_template",
         "extract_prompt",
         "is_conversational",
+        "is_conversational_from_value",
         "maybe_apply_chat_template",
         "maybe_convert_to_chatml",
         "maybe_extract_prompt",
         "maybe_unpair_preference_dataset",
         "pack_dataset",
+        "prepare_multimodal_messages",
         "truncate_dataset",
         "unpair_preference_dataset",
     ],
@@ -117,11 +119,13 @@ if TYPE_CHECKING:
         apply_chat_template,
         extract_prompt,
         is_conversational,
+        is_conversational_from_value,
         maybe_apply_chat_template,
         maybe_convert_to_chatml,
         maybe_extract_prompt,
         maybe_unpair_preference_dataset,
         pack_dataset,
+        prepare_multimodal_messages,
         truncate_dataset,
         unpair_preference_dataset,
     )
trl/data_utils.py

@@ -28,6 +28,54 @@ from transformers import PreTrainedTokenizerBase
 DatasetType = TypeVar("DatasetType", Dataset, DatasetDict)
 
 
+def prepare_multimodal_messages(messages: list[dict[str, Any]], num_images: int) -> None:
+    """
+    Convert messages into a structured multimodal format if needed.
+
+    Each message's content is transformed from a raw string into a list of typed parts. The first user message is
+    prefixed with an image placeholder, while all other user and assistant messages are wrapped as text entries.
+
+    Args:
+        messages (`list[dict[str, Any]]`):
+            Messages with `"role"` and `"content"`. Content may be a raw string before transformation.
+        num_images (`int`):
+            Number of images to include in the first user message. This is used to determine how many image
+            placeholders to add.
+
+    Example:
+    ```python
+    # Input
+    [
+        {"role": "user", "content": "What's in this image?"},
+        {"role": "assistant", "content": "It looks like a cat."},
+    ]
+
+    # Output (num_images=1)
+    [
+        {"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "What's in this image?"}]},
+        {"role": "assistant", "content": [{"type": "text", "text": "It looks like a cat."}]},
+    ]
+    ```
+    """
+    image_included = False
+    for message in messages:
+        if message["role"] == "system":
+            if isinstance(message["content"], str):  # if already prepared, the content will be a list
+                message["content"] = [{"type": "text", "text": message["content"]}]
+        elif message["role"] == "user":
+            if isinstance(message["content"], str) and not image_included:
+                placeholders = [{"type": "image"}] * num_images
+                message["content"] = [*placeholders, {"type": "text", "text": message["content"]}]
+                image_included = True
+            elif isinstance(message["content"], str) and image_included:
+                message["content"] = [{"type": "text", "text": message["content"]}]
+        elif message["role"] == "assistant":
+            if isinstance(message["content"], str):
+                message["content"] = [{"type": "text", "text": message["content"]}]
+        else:
+            raise ValueError(f"Invalid role in message: {message['role']}. Expected 'user', 'assistant', or 'system'.")
+
+
 def is_conversational(example: dict[str, Any]) -> bool:
     r"""
     Check if the example is in a conversational format.
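For quick reference, here is how the new helper behaves end to end. This is a minimal sketch assembled from the docstring and the tests above; it assumes only that `prepare_multimodal_messages` is importable from `trl.data_utils`, as the `__init__.py` hunk makes it.

```python
from trl.data_utils import prepare_multimodal_messages

messages = [
    {"role": "system", "content": "You are a helpful assistant"},
    {"role": "user", "content": "What color is the sky?"},
]

# Mutates in place and returns None: the system text is wrapped as a typed
# part, and the first user message is prefixed with `num_images` placeholders.
prepare_multimodal_messages(messages, num_images=2)

assert messages == [
    {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant"}]},
    {
        "role": "user",
        "content": [
            {"type": "image"},
            {"type": "image"},
            {"type": "text", "text": "What color is the sky?"},
        ],
    },
]
```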
@@ -123,7 +171,7 @@ def apply_chat_template(
             prompt_chosen = tokenizer.apply_chat_template(
                 example["prompt"] + example["chosen"], tools=tools, tokenize=False, **template_kwargs
             )
-            # DeepSeek-R1 inserts a <think> token when using `add_generation_prompt`, which can cause discrepancies
+            # DeepSeek-R1 inserts a <tool_call> token when using `add_generation_prompt`, which can cause discrepancies
             # between the prompt alone and the combined prompt+completion. To ensure consistency, we extract the
             # common prefix between the two. In most cases, this is a no-op.
             prompt = "".join(x for x, _ in takewhile(lambda x: x[0] == x[1], zip(prompt, prompt_chosen)))
@@ -133,14 +181,14 @@ def apply_chat_template(
             prompt_rejected = tokenizer.apply_chat_template(
                 example["prompt"] + example["rejected"], tools=tools, tokenize=False, **template_kwargs
             )
-            # Handle DeepSeek-R1 <think> token, see the above comment for details
+            # Handle DeepSeek-R1 <tool_call> token, see the above comment for details
             prompt = "".join(x for x, _ in takewhile(lambda x: x[0] == x[1], zip(prompt, prompt_rejected)))
             rejected = prompt_rejected[len(prompt) :]
         if "completion" in example:
             prompt_completion = tokenizer.apply_chat_template(
                 example["prompt"] + example["completion"], tools=tools, tokenize=False, **template_kwargs
             )
-            # Handle DeepSeek-R1 <think> token, see the above comment for details
+            # Handle DeepSeek-R1 <tool_call> token, see the above comment for details
             prompt = "".join(x for x, _ in takewhile(lambda x: x[0] == x[1], zip(prompt, prompt_completion)))
             completion = prompt_completion[len(prompt) :]
     else:  # implicit prompt case
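The `takewhile` lines above implement a character-wise common-prefix extraction between the rendered prompt and the rendered prompt+completion. A standalone sketch of the idea, using made-up template strings rather than real tokenizer output:

```python
from itertools import takewhile

# Hypothetical rendered strings: the chat template appended an extra token to
# the standalone prompt, so it is not a strict prefix of prompt+completion.
prompt = "<|user|>Hi<|assistant|><think>"
prompt_completion = "<|user|>Hi<|assistant|>It is blue."

# Keep the character-wise common prefix of the two renderings ...
prompt = "".join(x for x, _ in takewhile(lambda x: x[0] == x[1], zip(prompt, prompt_completion)))
# ... and everything after it is the clean completion text.
completion = prompt_completion[len(prompt):]

print(prompt)      # <|user|>Hi<|assistant|>
print(completion)  # It is blue.
```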
@@ -222,7 +270,7 @@ def maybe_apply_chat_template(
     ...     "completion": [{"role": "assistant", "content": "It is blue."}],
     ... }
     >>> apply_chat_template(example, tokenizer)
-    {'prompt': '<|user|>\nWhat color is the sky?<|end|>\n<|assistant|>\n', 'completion': 'It is blue.<|end|>\n<|endoftext|>'}
+    {'prompt': '<|user|>\nWhat color is the sky?<|end|>\n<|assistant|>\n', 'completion': 'It is blue.<|end|>\n'}
     ```
     """
     if is_conversational(example):
trl/trainer/grpo_trainer.py

@@ -50,7 +50,7 @@ from transformers import (
 from transformers.trainer_utils import seed_worker
 from transformers.utils import is_datasets_available, is_flash_attn_2_available, is_peft_available, is_rich_available
 
-from ..data_utils import apply_chat_template, is_conversational, maybe_apply_chat_template
+from ..data_utils import apply_chat_template, is_conversational, maybe_apply_chat_template, prepare_multimodal_messages
 from ..extras.profiling import profiling_context, profiling_decorator
 from ..extras.vllm_client import VLLMClient
 from ..import_utils import is_liger_kernel_available, is_vllm_available
@@ -1351,17 +1351,8 @@ class GRPOTrainer(Trainer):
             images = [example.get("image") for example in inputs]
             kwargs = {"images": [[img] for img in images]}
             for prompt in prompts:
-                if isinstance(prompt, list):
-                    for message in prompt:
-                        if not isinstance(message, dict):
-                            continue
-                        content = message.get("content")
-                        role = message.get("role")
-                        if isinstance(content, str):
-                            if role == "user":
-                                message["content"] = [{"type": "image"}, {"type": "text", "text": content}]
-                            elif role == "system":
-                                message["content"] = [{"type": "text", "text": content}]
+                if isinstance(prompt, list):  # i.e., when using conversational data
+                    prepare_multimodal_messages(prompt, num_images=1)
 
         prompts_text = [maybe_apply_chat_template(example, self.processing_class)["prompt"] for example in inputs]
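Compared with the deleted inline loop, the shared helper also wraps assistant strings, supports more than one image via `num_images`, and raises on unknown roles rather than silently skipping them. A small sketch of the single-image conversational case that GRPO now delegates (example messages are made up):

```python
from trl.data_utils import prepare_multimodal_messages

prompt = [
    {"role": "system", "content": "Answer briefly."},
    {"role": "user", "content": "Describe the image."},
]
# GRPO collects one image per example, hence num_images=1 in this code path.
prepare_multimodal_messages(prompt, num_images=1)

assert prompt == [
    {"role": "system", "content": [{"type": "text", "text": "Answer briefly."}]},
    {"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "Describe the image."}]},
]
```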
trl/trainer/sft_trainer.py

@@ -49,6 +49,7 @@ from ..data_utils import (
     is_conversational_from_value,
     maybe_convert_to_chatml,
     pack_dataset,
+    prepare_multimodal_messages,
     truncate_dataset,
 )
 from ..models import clone_chat_template, get_act_offloading_ctx_manager, prepare_peft_model
@@ -375,55 +376,12 @@ class DataCollatorForVisionLanguageModeling(DataCollatorMixin):
         else:
             raise KeyError(f"Unexpected input keys in examples: {list(examples[0].keys())}.")
 
-    @staticmethod
-    def prepare_multimodal_messages(messages: list[dict[str, Any]], num_images: int) -> None:
-        """
-        Convert messages into a structured multimodal format if needed.
-
-        Each message's content is transformed from a raw string into a list of typed parts. The first user message is
-        prefixed with an image placeholder, while all other user and assistant messages are wrapped as text entries.
-
-        Args:
-            messages (`list[dict[str, Any]]`):
-                Messages with "role" and "content". Content may be a raw string before transformation.
-            num_images (`int`):
-                Number of images to include in the first user message. This is used to determine how many image
-                placeholders to add.
-
-        Example:
-        ```python
-        # Input
-        [
-            {"role": "user", "content": "What's in this image?"},
-            {"role": "assistant", "content": "It looks like a cat."},
-        ]
-
-        # Output (num_images=1)
-        [
-            {"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "What's in this image?"}]},
-            {"role": "assistant", "content": [{"type": "text", "text": "It looks like a cat."}]},
-        ]
-        ```
-        """
-        image_included = False
-        for message in messages:
-            if message["role"] == "user":
-                if isinstance(message["content"], str) and not image_included:
-                    placeholders = [{"type": "image"}] * num_images
-                    message["content"] = [*placeholders, {"type": "text", "text": message["content"]}]
-                    image_included = True
-                elif isinstance(message["content"], str) and image_included:
-                    message["content"] = [{"type": "text", "text": message["content"]}]
-            if message["role"] == "assistant":
-                if isinstance(message["content"], str):
-                    message["content"] = [{"type": "text", "text": message["content"]}]
-
     def _collate_language_modeling(self, examples: list[Union[list[int], Any, dict[str, Any]]]) -> dict[str, Any]:
         images = [example["images"] for example in examples]
 
         if "messages" in examples[0]:  # conversational case
             for example in examples:
-                self.prepare_multimodal_messages(example["messages"], len(example["images"]))
+                prepare_multimodal_messages(example["messages"], len(example["images"]))
             messages = [example["messages"] for example in examples]
             texts = self.processor.apply_chat_template(messages)
         elif self.dataset_text_field in examples[0]:  # standard case
@@ -462,7 +420,7 @@ class DataCollatorForVisionLanguageModeling(DataCollatorMixin):
         images = [example["images"] for example in examples]
         if is_conversational(examples[0]):  # conversational case
             for example in examples:
-                self.prepare_multimodal_messages(example["prompt"] + example["completion"], len(example["images"]))
+                prepare_multimodal_messages(example["prompt"] + example["completion"], len(example["images"]))
         examples = [apply_chat_template(example, self.processor) for example in examples]
 
         prompts = [example["prompt"] for example in examples]
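Here the collator passes `len(example["images"])`, so an example with several images gets one placeholder per image in its first user message. A sketch with stand-in values, mirroring the field shapes in the collator code above:

```python
from trl.data_utils import prepare_multimodal_messages

image_a, image_b = object(), object()  # stand-ins for two PIL images

example = {
    "prompt": [{"role": "user", "content": "Compare these two photos."}],
    "completion": [{"role": "assistant", "content": "The second one is brighter."}],
    "images": [image_a, image_b],
}

# Concatenating prompt and completion lets one call prepare both; the
# concatenated list shares the underlying message dicts, which is why the
# in-place mutation is visible through `example` afterwards.
prepare_multimodal_messages(example["prompt"] + example["completion"], len(example["images"]))

assert example["prompt"][0]["content"] == [
    {"type": "image"},
    {"type": "image"},
    {"type": "text", "text": "Compare these two photos."},
]
```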