♻️ Reuse multimodal message preparation from SFTTrainer in GRPOTrainer (#3919)

Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
This commit is contained in:
MQY
2025-08-21 01:04:54 +08:00
committed by GitHub
parent 8793a46760
commit 85ead751f5
6 changed files with 180 additions and 62 deletions

View File

@ -1,9 +1,17 @@
# Data Utilities
## prepare_multimodal_messages
[[autodoc]] prepare_multimodal_messages
## is_conversational
[[autodoc]] is_conversational
## is_conversational_from_value
[[autodoc]] is_conversational_from_value
## apply_chat_template
[[autodoc]] apply_chat_template
@ -13,7 +21,7 @@
[[autodoc]] maybe_apply_chat_template
## maybe_convert_to_chatml
[[autodoc]] maybe_convert_to_chatml
## extract_prompt

View File

@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import itertools
import textwrap
import unittest
@ -31,6 +32,7 @@ from trl.data_utils import (
maybe_extract_prompt,
maybe_unpair_preference_dataset,
pack_dataset,
prepare_multimodal_messages,
truncate_dataset,
unpair_preference_dataset,
)
@ -38,6 +40,113 @@ from trl.data_utils import (
from .testing_utils import TrlTestCase
class PrepareMultimodalMessagesTester(unittest.TestCase):
    """Tests for `prepare_multimodal_messages`, which rewrites raw-string message contents into typed part lists in place."""

    @staticmethod
    def _text(value):
        # Shorthand for a structured text part.
        return {"type": "text", "text": value}

    @staticmethod
    def _image():
        # Shorthand for an image placeholder part.
        return {"type": "image"}

    def test_basic_user_assistant_conversation(self):
        """Test basic conversation with user and assistant messages."""
        conversation = [
            {"role": "user", "content": "What color is the sky?"},
            {"role": "assistant", "content": "It is blue."},
        ]
        prepare_multimodal_messages(conversation, num_images=1)
        want = [
            {"role": "user", "content": [self._image(), self._text("What color is the sky?")]},
            {"role": "assistant", "content": [self._text("It is blue.")]},
        ]
        self.assertEqual(conversation, want)

    def test_first_user_message_gets_image(self):
        """Test that only the first user message gets an image placeholder."""
        conversation = [
            {"role": "user", "content": "What color is the sky?"},
            {"role": "assistant", "content": "It is blue."},
            {"role": "user", "content": "How about the grass?"},
        ]
        prepare_multimodal_messages(conversation, num_images=1)
        want = [
            {"role": "user", "content": [self._image(), self._text("What color is the sky?")]},
            {"role": "assistant", "content": [self._text("It is blue.")]},
            {"role": "user", "content": [self._text("How about the grass?")]},
        ]
        self.assertEqual(conversation, want)

    def test_multiple_images(self):
        """Test that multiple images are added to the first user message."""
        conversation = [
            {"role": "user", "content": "What color is the sky?"},
            {"role": "assistant", "content": "It is blue."},
        ]
        prepare_multimodal_messages(conversation, num_images=3)
        want = [
            {
                "role": "user",
                "content": [self._image(), self._image(), self._image(), self._text("What color is the sky?")],
            },
            {"role": "assistant", "content": [self._text("It is blue.")]},
        ]
        self.assertEqual(conversation, want)

    def test_system_message_transformation(self):
        """Test that system messages are properly transformed."""
        conversation = [
            {"role": "system", "content": "You are a helpful assistant"},
            {"role": "user", "content": "What color is the sky?"},
        ]
        prepare_multimodal_messages(conversation, num_images=1)
        want = [
            {"role": "system", "content": [self._text("You are a helpful assistant")]},
            {"role": "user", "content": [self._image(), self._text("What color is the sky?")]},
        ]
        self.assertEqual(conversation, want)

    def test_already_prepared_messages_unchanged(self):
        """Test that messages with list content are not modified."""
        conversation = [
            {"role": "system", "content": [self._text("You are a helpful assistant")]},
            {"role": "user", "content": [self._image(), self._text("What color is the sky?")]},
            {"role": "assistant", "content": [self._text("It is blue.")]},
        ]
        snapshot = copy.deepcopy(conversation)
        prepare_multimodal_messages(conversation, num_images=1)
        self.assertEqual(conversation, snapshot)

    def test_mixed_prepared_and_unprepared_messages(self):
        """Test handling of mixed prepared and unprepared messages."""
        conversation = [
            {"role": "user", "content": "What color is the sky?"},
            {"role": "assistant", "content": [self._text("It is blue.")]},
            {"role": "user", "content": "What about the grass?"},
        ]
        prepare_multimodal_messages(conversation, num_images=1)
        want = [
            {"role": "user", "content": [self._image(), self._text("What color is the sky?")]},
            {"role": "assistant", "content": [self._text("It is blue.")]},
            {"role": "user", "content": [self._text("What about the grass?")]},
        ]
        self.assertEqual(conversation, want)
class IsConversationalTester(TrlTestCase):
conversational_examples = [
{ # Language modeling

View File

@ -25,11 +25,13 @@ _import_structure = {
"apply_chat_template",
"extract_prompt",
"is_conversational",
"is_conversational_from_value",
"maybe_apply_chat_template",
"maybe_convert_to_chatml",
"maybe_extract_prompt",
"maybe_unpair_preference_dataset",
"pack_dataset",
"prepare_multimodal_messages",
"truncate_dataset",
"unpair_preference_dataset",
],
@ -117,11 +119,13 @@ if TYPE_CHECKING:
apply_chat_template,
extract_prompt,
is_conversational,
is_conversational_from_value,
maybe_apply_chat_template,
maybe_convert_to_chatml,
maybe_extract_prompt,
maybe_unpair_preference_dataset,
pack_dataset,
prepare_multimodal_messages,
truncate_dataset,
unpair_preference_dataset,
)

View File

@ -28,6 +28,54 @@ from transformers import PreTrainedTokenizerBase
DatasetType = TypeVar("DatasetType", Dataset, DatasetDict)
def prepare_multimodal_messages(messages: list[dict[str, Any]], num_images: int) -> None:
    """
    Convert messages into a structured multimodal format if needed. The conversion happens in place.

    Each message's content is transformed from a raw string into a list of typed parts. The first user message is
    prefixed with image placeholders, while all other user, assistant, and system messages are wrapped as text
    entries. Messages whose content is already a list are left untouched.

    Args:
        messages (`list[dict[str, Any]]`):
            Messages with `"role"` and `"content"`. Content may be a raw string before transformation.
        num_images (`int`):
            Number of images to include in the first user message. This is used to determine how many image
            placeholders to add.

    Raises:
        `ValueError`: If a message role is not `"user"`, `"assistant"`, or `"system"`.

    Example:
    ```python
    # Input
    [
        {"role": "user", "content": "What's in this image?"},
        {"role": "assistant", "content": "It looks like a cat."},
    ]

    # Output (num_images=1)
    [
        {"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "What's in this image?"}]},
        {"role": "assistant", "content": [{"type": "text", "text": "It looks like a cat."}]},
    ]
    ```
    """
    image_included = False
    # NOTE: the role checks must form a single if/elif chain. With independent `if` statements the trailing
    # `else` binds only to the assistant check, making every system/user message raise ValueError.
    for message in messages:
        if message["role"] == "system":
            if isinstance(message["content"], str):  # if already prepared, the content will be a list
                message["content"] = [{"type": "text", "text": message["content"]}]
        elif message["role"] == "user":
            if isinstance(message["content"], str) and not image_included:
                # Only the first unprepared user message carries the image placeholders.
                placeholders = [{"type": "image"}] * num_images
                message["content"] = [*placeholders, {"type": "text", "text": message["content"]}]
                image_included = True
            elif isinstance(message["content"], str) and image_included:
                message["content"] = [{"type": "text", "text": message["content"]}]
        elif message["role"] == "assistant":
            if isinstance(message["content"], str):
                message["content"] = [{"type": "text", "text": message["content"]}]
        else:
            raise ValueError(f"Invalid role in message: {message['role']}. Expected 'user', 'assistant', or 'system'.")
def is_conversational(example: dict[str, Any]) -> bool:
r"""
Check if the example is in a conversational format.
@ -123,7 +171,7 @@ def apply_chat_template(
prompt_chosen = tokenizer.apply_chat_template(
example["prompt"] + example["chosen"], tools=tools, tokenize=False, **template_kwargs
)
# DeepSeek-R1 inserts a <think> token when using `add_generation_prompt`, which can cause discrepancies
# DeepSeek-R1 inserts a <tool_call> token when using `add_generation_prompt`, which can cause discrepancies
# between the prompt alone and the combined prompt+completion. To ensure consistency, we extract the
# common prefix between the two. In most cases, this is a no-op.
prompt = "".join(x for x, _ in takewhile(lambda x: x[0] == x[1], zip(prompt, prompt_chosen)))
@ -133,14 +181,14 @@ def apply_chat_template(
prompt_rejected = tokenizer.apply_chat_template(
example["prompt"] + example["rejected"], tools=tools, tokenize=False, **template_kwargs
)
# Handle DeepSeek-R1 <think> token, see the above comment for details
# Handle DeepSeek-R1 <tool_call> token, see the above comment for details
prompt = "".join(x for x, _ in takewhile(lambda x: x[0] == x[1], zip(prompt, prompt_rejected)))
rejected = prompt_rejected[len(prompt) :]
if "completion" in example:
prompt_completion = tokenizer.apply_chat_template(
example["prompt"] + example["completion"], tools=tools, tokenize=False, **template_kwargs
)
# Handle DeepSeek-R1 <think> token, see the above comment for details
# Handle DeepSeek-R1 <tool_call> token, see the above comment for details
prompt = "".join(x for x, _ in takewhile(lambda x: x[0] == x[1], zip(prompt, prompt_completion)))
completion = prompt_completion[len(prompt) :]
else: # implicit prompt case
@ -222,7 +270,7 @@ def maybe_apply_chat_template(
... "completion": [{"role": "assistant", "content": "It is blue."}],
... }
>>> apply_chat_template(example, tokenizer)
{'prompt': '<|user|>\nWhat color is the sky?<|end|>\n<|assistant|>\n', 'completion': 'It is blue.<|end|>\n<|endoftext|>'}
{'prompt': '<|user|>\nWhat color is the sky?<|end|>\n<|assistant|>\n', 'completion': 'It is blue.<|end|>\n'}
```
"""
if is_conversational(example):

View File

@ -50,7 +50,7 @@ from transformers import (
from transformers.trainer_utils import seed_worker
from transformers.utils import is_datasets_available, is_flash_attn_2_available, is_peft_available, is_rich_available
from ..data_utils import apply_chat_template, is_conversational, maybe_apply_chat_template
from ..data_utils import apply_chat_template, is_conversational, maybe_apply_chat_template, prepare_multimodal_messages
from ..extras.profiling import profiling_context, profiling_decorator
from ..extras.vllm_client import VLLMClient
from ..import_utils import is_liger_kernel_available, is_vllm_available
@ -1351,17 +1351,8 @@ class GRPOTrainer(Trainer):
images = [example.get("image") for example in inputs]
kwargs = {"images": [[img] for img in images]}
for prompt in prompts:
if isinstance(prompt, list):
for message in prompt:
if not isinstance(message, dict):
continue
content = message.get("content")
role = message.get("role")
if isinstance(content, str):
if role == "user":
message["content"] = [{"type": "image"}, {"type": "text", "text": content}]
elif role == "system":
message["content"] = [{"type": "text", "text": content}]
if isinstance(prompt, list): # i.e., when using conversational data
prepare_multimodal_messages(prompt, num_images=1)
prompts_text = [maybe_apply_chat_template(example, self.processing_class)["prompt"] for example in inputs]

View File

@ -49,6 +49,7 @@ from ..data_utils import (
is_conversational_from_value,
maybe_convert_to_chatml,
pack_dataset,
prepare_multimodal_messages,
truncate_dataset,
)
from ..models import clone_chat_template, get_act_offloading_ctx_manager, prepare_peft_model
@ -375,55 +376,12 @@ class DataCollatorForVisionLanguageModeling(DataCollatorMixin):
else:
raise KeyError(f"Unexpected input keys in examples: {list(examples[0].keys())}.")
@staticmethod
def prepare_multimodal_messages(messages: list[dict[str, Any]], num_images: int) -> None:
"""
Convert messages into a structured multimodal format if needed.
Each message's content is transformed from a raw string into a list of typed parts. The first user message is
prefixed with an image placeholder, while all other user and assistant messages are wrapped as text entries.
Args:
messages (`list[dict[str, Any]]`):
Messages with "role" and "content". Content may be a raw string before transformation.
num_images (`int`):
Number of images to include in the first user message. This is used to determine how many image
placeholders to add.
Example:
```python
# Input
[
{"role": "user", "content": "What's in this image?"},
{"role": "assistant", "content": "It looks like a cat."},
]
# Output (num_images=1)
[
{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "What's in this image?"}]},
{"role": "assistant", "content": [{"type": "text", "text": "It looks like a cat."}]},
]
```
"""
image_included = False
for message in messages:
if message["role"] == "user":
if isinstance(message["content"], str) and not image_included:
placeholders = [{"type": "image"}] * num_images
message["content"] = [*placeholders, {"type": "text", "text": message["content"]}]
image_included = True
elif isinstance(message["content"], str) and image_included:
message["content"] = [{"type": "text", "text": message["content"]}]
if message["role"] == "assistant":
if isinstance(message["content"], str):
message["content"] = [{"type": "text", "text": message["content"]}]
def _collate_language_modeling(self, examples: list[Union[list[int], Any, dict[str, Any]]]) -> dict[str, Any]:
images = [example["images"] for example in examples]
if "messages" in examples[0]: # conversational case
for example in examples:
self.prepare_multimodal_messages(example["messages"], len(example["images"]))
prepare_multimodal_messages(example["messages"], len(example["images"]))
messages = [example["messages"] for example in examples]
texts = self.processor.apply_chat_template(messages)
elif self.dataset_text_field in examples[0]: # standard case
@ -462,7 +420,7 @@ class DataCollatorForVisionLanguageModeling(DataCollatorMixin):
images = [example["images"] for example in examples]
if is_conversational(examples[0]): # conversational case
for example in examples:
self.prepare_multimodal_messages(example["prompt"] + example["completion"], len(example["images"]))
prepare_multimodal_messages(example["prompt"] + example["completion"], len(example["images"]))
examples = [apply_chat_template(example, self.processor) for example in examples]
prompts = [example["prompt"] for example in examples]