mirror of https://github.com/huggingface/trl.git
synced 2025-10-20 18:43:52 +08:00

Compare commits: v0.22.0 ... v0.15-rele (10 commits)
Commits (SHA1):
- fc2b041b58
- d4098d1dec
- e437710883
- 51d383efca
- 38c33c547f
- 8aa13d809c
- 99c78f123f
- 2ccca1d0aa
- 7596db89da
- def8e48dce
@@ -12,6 +12,10 @@
 [[autodoc]] maybe_apply_chat_template
 
+## maybe_convert_to_chatml
+
+[[autodoc]] maybe_convert_to_chatml
+
 ## extract_prompt
 
 [[autodoc]] extract_prompt
setup.py (9 changed lines)

@@ -71,7 +71,7 @@ To create the package for PyPI.
 from setuptools import find_packages, setup
 
 
-__version__ = "0.15.0"  # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
+__version__ = "0.15.2"  # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
 
 REQUIRED_PKGS = [
     "accelerate>=0.34.0",
@@ -85,13 +85,16 @@ EXTRAS = {
     "diffusers": ["diffusers>=0.18.0"],
     "judges": ["openai>=1.23.2", "llm-blender>=0.0.2"],
     # liger-kernel depends on triton, which is only available on Linux https://github.com/triton-lang/triton#compatibility
-    "liger": ["liger-kernel>=0.4.0; sys_platform != 'win32'"],
+    # can be set to >=0.5.3 when https://github.com/linkedin/Liger-Kernel/issues/586 is fixed
+    "liger": ["liger-kernel==0.5.3; sys_platform != 'win32'"],
     "mergekit": ["mergekit>=0.0.5.1"],
     "peft": ["peft>=0.8.0"],
     "quantization": ["bitsandbytes"],
     "scikit": ["scikit-learn"],
     "test": ["parameterized", "pytest-cov", "pytest-rerunfailures", "pytest-xdist", "pytest"],
-    "vllm": ["vllm>=0.7.1; sys_platform != 'win32'"],  # vllm is not available on Windows
+    # vllm is not available on Windows
+    # vllm 0.7.3 causes hanging while gathering. temporary pinning the version until the issue is resolved
+    "vllm": ["vllm==0.7.2; sys_platform != 'win32'"],
     "vlm": ["Pillow"],
 }
 EXTRAS["dev"] = []
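The `sys_platform != 'win32'` suffixes in the liger and vllm entries are standard PEP 508 environment markers, so on Windows those extras simply resolve to nothing, as the comments note. A quick illustration of how such a marker is evaluated (a sketch using the third-party `packaging` library, not part of this diff):

```python
# Sketch: evaluating the environment markers used in the EXTRAS entries above.
# Requires the `packaging` library (pip install packaging).
from packaging.requirements import Requirement

req = Requirement("liger-kernel==0.5.3; sys_platform != 'win32'")
print(req.name, req.specifier)                          # liger-kernel ==0.5.3
print(req.marker.evaluate())                            # True on Linux/macOS
print(req.marker.evaluate({"sys_platform": "win32"}))   # False, so the extra is skipped on Windows
```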
@@ -24,6 +24,7 @@ from trl.data_utils import (
     extract_prompt,
     is_conversational,
     maybe_apply_chat_template,
+    maybe_convert_to_chatml,
     maybe_extract_prompt,
     maybe_unpair_preference_dataset,
     pack_examples,
@@ -435,6 +436,51 @@ class TestPackExamples(unittest.TestCase):
         self.assertEqual(dataset.to_dict(), expected_output)
 
 
+class TestMaybeConvertToChatML(unittest.TestCase):
+    def test_with_conversations_key(self):
+        # Particular case where the key is "conversations": we rename it to "messages"
+        example = {
+            "conversations": [
+                {"from": "user", "value": "What color is the sky?"},
+                {"from": "assistant", "value": "It is blue."},
+            ]
+        }
+        expected_output = {
+            "messages": [
+                {"role": "user", "content": "What color is the sky?"},
+                {"role": "assistant", "content": "It is blue."},
+            ]
+        }
+        self.assertEqual(maybe_convert_to_chatml(example), expected_output)
+
+    def test_without_conversations_key(self):
+        # Same as before, but we don't rename the keys
+        example = {
+            "prompt": [{"from": "user", "value": "What color is the sky?"}],
+            "completion": [{"from": "assistant", "value": "It is blue."}],
+        }
+        expected_output = {
+            "prompt": [{"role": "user", "content": "What color is the sky?"}],
+            "completion": [{"role": "assistant", "content": "It is blue."}],
+        }
+        self.assertEqual(maybe_convert_to_chatml(example), expected_output)
+
+    def test_not_conversional(self):
+        # When not needed, the example should remain unchanged
+        example = {"text": "The sky is blue."}
+        self.assertEqual(maybe_convert_to_chatml(example), example)
+
+    def test_already_chatml(self):
+        # When the example is already in ChatML format, it should remain unchanged
+        example = {
+            "messages": [
+                {"role": "user", "content": "What color is the sky?"},
+                {"role": "assistant", "content": "It is blue."},
+            ]
+        }
+        self.assertEqual(maybe_convert_to_chatml(example), example)
+
+
 # Run the tests
 if __name__ == "__main__":
     unittest.main()
@@ -288,7 +288,7 @@ class SFTTrainerTester(unittest.TestCase):
             self.assertIn("model.safetensors", os.listdir(tmp_dir + "/checkpoint-2"))
 
-    def test_sft_trainer_with_pretokenzied_data_packing(self):
+    def test_sft_trainer_with_pretokenized_data_packing(self):
         with tempfile.TemporaryDirectory() as tmp_dir:
             training_args = SFTConfig(
                 output_dir=tmp_dir,
@@ -1370,3 +1370,65 @@ class SFTTrainerTester2(unittest.TestCase):
                     "base_layer" not in n
                 ):  # We expect the peft parameters to be different (except for the base layer)
                     self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")
+
+    def test_train_with_non_chatml_conversational_data(self):
+        model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
+        model = AutoModelForCausalLM.from_pretrained(model_id)
+        dataset = load_dataset("trl-internal-testing/zen", "conversational_language_modeling", split="train")
+
+        # Rename role/content to from/value to ensure SFT works with non-chatML conversational data
+        def rename_fields(example: list[dict]):
+            return {"conversations": [{"from": m["role"], "value": m["content"]} for m in example["messages"]]}
+
+        dataset = dataset.map(rename_fields, remove_columns="messages")
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            # Initialize the trainer
+            training_args = SFTConfig(output_dir=tmp_dir, report_to="none")
+            trainer = SFTTrainer(args=training_args, model=model, train_dataset=dataset)
+
+            # Save the initial parameters to compare them later
+            previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
+
+            # Train the model
+            trainer.train()
+
+            # Check that the training loss is not None
+            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+
+            # Check the params have changed
+            for n, param in previous_trainable_params.items():
+                new_param = trainer.model.get_parameter(n)
+                self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")
+
+    def test_sft_trainer_with_pretokenized_data(self):
+        # Get the model and dataset
+        model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
+        model = AutoModelForCausalLM.from_pretrained(model_id)
+        tokenizer = AutoTokenizer.from_pretrained(model_id)
+        dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train")
+
+        def tokenize_example(example):
+            return tokenizer(example["text"])
+
+        # Apply tokenization
+        tokenized_dataset = dataset.map(tokenize_example, remove_columns=["text"])
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            # Initialize the trainer
+            training_args = SFTConfig(output_dir=tmp_dir, report_to="none")
+            trainer = SFTTrainer(args=training_args, model=model, train_dataset=tokenized_dataset)
+
+            # Save the initial parameters to compare them later
+            previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
+
+            # Train the model
+            trainer.train()
+
+            # Check that the training loss is not None
+            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+
+            # Check the params have changed
+            for n, param in previous_trainable_params.items():
+                new_param = trainer.model.get_parameter(n)
+                self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")
@@ -27,7 +27,6 @@ from trl.trainer import compute_accuracy
 from trl.trainer.utils import (
     DataCollatorForChatML,
     batch_generation,
-    compute_token_accuracy,
     decode_and_strip_padding,
     flush_left,
     generate_model_card,
@@ -456,60 +455,6 @@ class TestFlushLeft(unittest.TestCase):
         self.assertTrue(torch.equal(new_mask, expected_mask))
 
 
-class TestComputeTokenAccuracy(unittest.TestCase):
-    def test_basic_accuracy(self):
-        # Test basic accuracy computation
-        logits = torch.tensor([[[0.9, 0.1], [0.8, 0.2]], [[0.3, 0.7], [0.6, 0.4]]])  # Shape: [2, 2, 2]
-        labels = torch.tensor([[1, 0], [1, 0]])  # Shape: [2, 2]
-        accuracy = compute_token_accuracy(logits, labels)
-        self.assertAlmostEqual(accuracy, 0.75)  # 3 correct out of 4 tokens
-
-    def test_with_ignore_index(self):
-        # Test accuracy computation with ignored tokens
-        logits = torch.tensor([[[0.9, 0.1], [0.8, 0.2]], [[0.3, 0.7], [0.6, 0.4]]])
-        labels = torch.tensor([[1, -100], [1, 0]])  # -100 is ignored
-        accuracy = compute_token_accuracy(logits, labels, ignore_index=-100)
-        self.assertAlmostEqual(accuracy, 2 / 3)  # 2 correct out of 3 non-ignored tokens
-
-    def test_all_ignored(self):
-        # Test case where all tokens are ignored
-        logits = torch.tensor([[[0.1, 0.9], [0.8, 0.2]]])
-        labels = torch.tensor([[-100, -100]])
-        accuracy = compute_token_accuracy(logits, labels)
-        self.assertEqual(accuracy, 0.0)  # No valid tokens to compute accuracy
-
-    def test_perfect_accuracy(self):
-        # Test case with 100% accuracy
-        logits = torch.tensor([[[0.1, 0.9], [0.8, 0.2]]])
-        labels = torch.tensor([[1, 0]])
-        accuracy = compute_token_accuracy(logits, labels)
-        self.assertEqual(accuracy, 1.0)  # All predictions correct
-
-    def test_zero_accuracy(self):
-        # Test case with 0% accuracy
-        logits = torch.tensor([[[0.1, 0.9], [0.8, 0.2]]])
-        labels = torch.tensor([[0, 1]])
-        accuracy = compute_token_accuracy(logits, labels)
-        self.assertEqual(accuracy, 0.0)  # All predictions wrong
-
-    def test_batch_accuracy(self):
-        # Test accuracy computation across multiple batches
-        logits = torch.tensor(
-            [
-                [[0.9, 0.1], [0.8, 0.2], [0.3, 0.7]],  # Batch 1
-                [[0.2, 0.8], [0.7, 0.3], [0.6, 0.4]],  # Batch 2
-            ]
-        )
-        labels = torch.tensor(
-            [
-                [1, 0, 1],  # Batch 1
-                [1, 0, -100],  # Batch 2 (last token ignored)
-            ]
-        )
-        accuracy = compute_token_accuracy(logits, labels)
-        self.assertAlmostEqual(accuracy, 0.8)
-
-
 class TestSelectiveLogSoftmax(unittest.TestCase):
     @parameterized.expand([(torch.float64,), (torch.float32,), (torch.float16,), (torch.bfloat16,)])
     def test_selective_log_softmax(self, dtype):
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-__version__ = "0.15.0"
+__version__ = "0.15.2"
 
 from typing import TYPE_CHECKING
 
@@ -26,6 +26,7 @@ _import_structure = {
         "extract_prompt",
         "is_conversational",
         "maybe_apply_chat_template",
+        "maybe_convert_to_chatml",
         "maybe_extract_prompt",
         "maybe_unpair_preference_dataset",
         "pack_examples",
@@ -101,7 +102,7 @@ _import_structure = {
         "XPOTrainer",
     ],
     "trainer.callbacks": ["MergeModelCallback", "RichProgressCallback", "SyncRefModelCallback"],
-    "trainer.utils": ["get_kbit_device_map", "get_peft_config", "get_quantization_config", "compute_token_accuracy"],
+    "trainer.utils": ["get_kbit_device_map", "get_peft_config", "get_quantization_config"],
 }
 
 try:
@@ -126,6 +127,7 @@ if TYPE_CHECKING:
         extract_prompt,
         is_conversational,
         maybe_apply_chat_template,
+        maybe_convert_to_chatml,
         maybe_extract_prompt,
         maybe_unpair_preference_dataset,
         pack_examples,
@@ -202,7 +204,7 @@ if TYPE_CHECKING:
         XPOTrainer,
     )
     from .trainer.callbacks import RichProgressCallback, SyncRefModelCallback
-    from .trainer.utils import compute_token_accuracy, get_kbit_device_map, get_peft_config, get_quantization_config
+    from .trainer.utils import get_kbit_device_map, get_peft_config, get_quantization_config
 
     try:
         if not is_diffusers_available():
@@ -31,7 +31,8 @@ def is_conversational(example: dict[str, Any]) -> bool:
             dataset type.
 
     Returns:
-        `bool`: `True` if the data is in a conversational format, `False` otherwise.
+        `bool`:
+            `True` if the data is in a conversational format, `False` otherwise.
 
     Examples:
 
@@ -185,20 +186,21 @@ def maybe_apply_chat_template(
             For keys `"messages"`, `"prompt"`, `"chosen"`, `"rejected"`, and `"completion"`, the values are lists of
             messages, where each message is a dictionary with keys `"role"` and `"content"`.
         tokenizer (`PreTrainedTokenizer`):
-            The tokenizer to apply the chat template with.
+            Tokenizer to apply the chat template with.
         tools (`list[Union[dict, Callable]]` or `None`, *optional*, defaults to `None`):
             A list of tools (callable functions) that will be accessible to the model.
             If the template does not support function calling, this argument will have no effect
 
     Returns:
-        `dict[str, str]`: The formatted example with the chat template applied.
+        `dict[str, str]`:
+            Formatted example with the chat template applied.
 
     Notes:
-        - This function does not alter the keys, except for Language modeling dataset, where `"messages"` is replaced by
-          `"text"`.
+        - This function does not alter the keys, except for Language modeling dataset, where `"messages"` is replaced
+          by `"text"`.
 
-        - In case of prompt-only data, if the last role is `"user"`, the generation prompt is added to the prompt. Else,
-          if the last role is `"assistant"`, the final message is continued.
+        - In case of prompt-only data, if the last role is `"user"`, the generation prompt is added to the prompt.
+          Else, if the last role is `"assistant"`, the final message is continued.
 
     Example:
 
@@ -462,3 +464,52 @@ def pack_examples(examples: dict[str, list[list]], seq_length: int) -> dict[str,
     # Split the values into chunks of size seq_length
     examples = {k: [v[i : i + seq_length] for i in range(0, len(v), seq_length)] for k, v in examples.items()}
     return examples
+
+
+def maybe_convert_to_chatml(example: dict[str, list]) -> dict[str, list]:
+    """
+    Convert a conversational dataset with fields `from` and `value` to ChatML format.
+
+    This function modifies conversational data to align with OpenAI's ChatML format:
+    - Replaces the key `"from"` with `"role"` in message dictionaries.
+    - Replaces the key `"value"` with `"content"` in message dictionaries.
+    - Renames `"conversations"` to `"messages"` for consistency with ChatML.
+
+    Args:
+        example (`dict[str, list]`):
+            A single data entry containing a list of messages.
+
+    Returns:
+        `dict[str, list]`:
+            Example reformatted to ChatML style.
+
+    Example:
+    ```python
+    >>> from trl import maybe_convert_to_chatml
+    >>> example = {
+    ...     "conversations": [
+    ...         {"from": "user", "value": "What color is the sky?"},
+    ...         {"from": "assistant", "value": "It is blue."}
+    ...     ]
+    ... }
+    >>> maybe_convert_to_chatml(example)
+    {'messages': [{'role': 'user', 'content': 'What color is the sky?'},
+                  {'role': 'assistant', 'content': 'It is blue.'}]}
+    ```
+    """
+    # List of possible keys containing message lists
+    for key in ["prompt", "completion", "chosen", "rejected", "messages", "conversations"]:
+        if key in example and isinstance(example[key], list):
+            messages = example[key]
+            for message in messages:
+                if isinstance(message, dict):
+                    if "from" in message:
+                        message["role"] = message.pop("from")
+                    if "value" in message:
+                        message["content"] = message.pop("value")
+
+    # Rename "conversations" to "messages"
+    if "conversations" in example:
+        example["messages"] = example.pop("conversations")
+
+    return example
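Beyond the single-example call in the docstring, the new function is meant to be mapped over a whole dataset, which is how the SFT trainer uses it later in this compare. A minimal sketch (assuming trl 0.15.2 and the `datasets` library; the sample data is made up):

```python
# Sketch: converting a "from"/"value" style dataset to ChatML with datasets.map.
from datasets import Dataset
from trl import maybe_convert_to_chatml

dataset = Dataset.from_dict({
    "conversations": [
        [
            {"from": "user", "value": "What color is the sky?"},
            {"from": "assistant", "value": "It is blue."},
        ]
    ]
})

# Drop the old "conversations" column so only the renamed "messages" column remains,
# mirroring the remove_columns logic added to SFTTrainer further down in this compare.
dataset = dataset.map(maybe_convert_to_chatml, remove_columns="conversations")
print(dataset[0]["messages"][0])  # role/content keys instead of from/value
```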
@@ -76,7 +76,6 @@ _import_structure = {
         "disable_dropout_in_model",
         "empty_cache",
         "peft_module_casting_to_bf16",
-        "compute_token_accuracy",
     ],
     "xpo_config": ["XPOConfig"],
     "xpo_trainer": ["XPOTrainer"],
|
||||
DataCollatorForCompletionOnlyLM,
|
||||
RunningMoments,
|
||||
compute_accuracy,
|
||||
compute_token_accuracy,
|
||||
disable_dropout_in_model,
|
||||
empty_cache,
|
||||
peft_module_casting_to_bf16,
|
||||
|
@@ -495,7 +495,6 @@ class GRPOTrainer(Trainer):
             if is_peft_model(unwrapped_model):
                 unwrapped_model.merge_adapter()
                 state_dict = unwrapped_model.state_dict()
-                unwrapped_model.unmerge_adapter()
                 # Remove base_model and base_layer prefixes
                 state_dict = {
                     k.removeprefix("base_model.model.").replace(".base_layer", ""): v for k, v in state_dict.items()
@@ -510,9 +509,13 @@ class GRPOTrainer(Trainer):
                 }
             else:
                 state_dict = unwrapped_model.state_dict()
-        if self.accelerator.is_main_process:
-            llm_model = self.llm.llm_engine.model_executor.driver_worker.model_runner.model
-            llm_model.load_weights(state_dict.items())
+            if self.accelerator.is_main_process:
+                llm_model = self.llm.llm_engine.model_executor.driver_worker.model_runner.model
+                llm_model.load_weights(state_dict.items())
+            # Unmerge the adapter to restore the model to its original state.
+            # This must be done after loading weights to ensure they correspond to the merged state.
+            if is_peft_model(unwrapped_model):
+                unwrapped_model.unmerge_adapter()
 
     def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[str, Union[torch.Tensor, Any]]:
         device = self.accelerator.device
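The comments in the hunk above carry the key constraint: the adapter may only be unmerged after the state dict has been handed to vLLM, otherwise the weights the engine receives no longer reflect the merged state. A minimal sketch of that ordering, assuming a PEFT-wrapped model (`sync_merged_weights` and `load_weights_fn` are made-up names for illustration, not the trainer's API):

```python
# Sketch of the weight-sync ordering established above (hypothetical helper, not TRL API).
from peft import PeftModel


def sync_merged_weights(model, load_weights_fn):
    """Push LoRA-merged weights to an inference engine, then restore the adapter."""
    is_peft = isinstance(model, PeftModel)
    if is_peft:
        model.merge_adapter()  # fold the adapter deltas into the base weights
    state_dict = model.state_dict()  # tensors reference the merged weights
    load_weights_fn(state_dict.items())  # e.g. the vLLM model's load_weights
    if is_peft:
        # Unmerging earlier would mutate the very tensors the engine reads from.
        model.unmerge_adapter()
```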
@@ -43,15 +43,9 @@ from transformers.trainer_utils import EvalPrediction
 from transformers.utils import is_liger_kernel_available, is_peft_available
 from transformers.utils.deprecation import deprecate_kwarg
 
-from ..data_utils import is_conversational, maybe_apply_chat_template, pack_examples
+from ..data_utils import is_conversational, maybe_apply_chat_template, maybe_convert_to_chatml, pack_examples
 from .sft_config import SFTConfig
-from .utils import (
-    ConstantLengthDataset,
-    compute_token_accuracy,
-    generate_model_card,
-    get_comet_experiment_url,
-    peft_module_casting_to_bf16,
-)
+from .utils import ConstantLengthDataset, generate_model_card, get_comet_experiment_url, peft_module_casting_to_bf16
 
 
 if is_peft_available():
@@ -107,6 +101,8 @@ class SFTTrainer(Trainer):
               - [Standard](dataset_formats#standard): Each sample contains plain text.
               - [Conversational](dataset_formats#conversational): Each sample contains structured messages (e.g., role
                 and content).
+
+            The trainer also supports processed datasets (tokenized) as long as they contain an `input_ids` field.
         eval_dataset ([`~datasets.Dataset`], [`~datasets.IterableDataset`] or `dict[str, Union[Dataset, IterableDataset]]`):
             Dataset to use for evaluation. It must meet the same requirements as `train_dataset`.
         processing_class ([`~transformers.PreTrainedTokenizerBase`], *optional*, defaults to `None`):
@@ -186,7 +182,7 @@ class SFTTrainer(Trainer):
         if peft_config is not None:
             model = self._prepare_peft_model(model, peft_config, args)
 
-        # 3. Handle the tokenizer
+        # Handle the tokenizer
         if processing_class is None:
             processing_class = AutoTokenizer.from_pretrained(model.config._name_or_path)
             if processing_class.pad_token is None:
|
||||
if args.use_liger:
|
||||
if not is_liger_kernel_available():
|
||||
raise ImportError("Please install Liger-kernel for use_liger=True")
|
||||
return AutoLigerKernelForCausalLM.from_pretrained(model_path, **model_init_kwargs)
|
||||
return AutoModelForCausalLM.from_pretrained(model_path, **model_init_kwargs)
|
||||
model = AutoLigerKernelForCausalLM.from_pretrained(model_path, **model_init_kwargs)
|
||||
else:
|
||||
model = AutoModelForCausalLM.from_pretrained(model_path, **model_init_kwargs)
|
||||
return model
|
||||
|
||||
def _prepare_peft_model(self, model: PreTrainedModel, peft_config: Any, args: SFTConfig) -> PreTrainedModel:
|
||||
"""Prepares a model for PEFT training."""
|
||||
@@ -368,6 +366,10 @@ class SFTTrainer(Trainer):
         if isinstance(dataset, ConstantLengthDataset):
             return dataset
 
+        # If the dataset is already preprocessed (tokenized), skip the processing steps.
+        column_names = list(next(iter(dataset)).keys())
+        is_processed = "input_ids" in column_names
+
         # Build the kwargs for the `map` function
         map_kwargs = {}
         if isinstance(dataset, Dataset):  # IterableDataset does not support num_proc
@@ -375,7 +377,15 @@ class SFTTrainer(Trainer):
 
         with PartialState().local_main_process_first():
             # Apply the formatting function if any
-            if formatting_func is not None:
+            if formatting_func is not None and is_processed:
+                warnings.warn(
+                    "You passed a dataset that is already processed (contains an `input_ids` field) together with a "
+                    "formatting function. Therefore `formatting_func` will be ignored. Either remove the "
+                    "`formatting_func` or pass a dataset that is not already processed.",
+                    UserWarning,
+                )
+
+            if formatting_func is not None and not is_processed:
                 if isinstance(dataset, Dataset):  # `IterableDataset.map` does not support `desc`
                     map_kwargs["desc"] = f"Applying formatting function to {dataset_name} dataset"
 
@@ -395,6 +405,15 @@ class SFTTrainer(Trainer):
 
                 dataset = dataset.map(concat_prompt_completion, remove_columns=["prompt", "completion"])
 
+            # Convert the dataset to ChatML if needed
+            if isinstance(dataset, Dataset):  # `IterableDataset.map` does not support `desc`
+                map_kwargs["desc"] = f"Converting {dataset_name} dataset to ChatML"
+            dataset = dataset.map(
+                maybe_convert_to_chatml,
+                remove_columns="conversations" if "conversations" in dataset.column_names else None,
+                **map_kwargs,
+            )
+
             # Apply the chat template if needed
             if isinstance(dataset, Dataset):  # `IterableDataset.map` does not support `desc`
                 map_kwargs["desc"] = f"Applying chat template to {dataset_name} dataset"
@@ -405,10 +424,19 @@ class SFTTrainer(Trainer):
                 **map_kwargs,
             )
 
-            # Tokenize the dataset
-            if isinstance(dataset, Dataset):  # `IterableDataset.map` does not support `desc`
-                map_kwargs["desc"] = f"Tokenizing {dataset_name} dataset"
-            dataset = dataset.map(lambda ex: processing_class(ex[args.dataset_text_field]), **map_kwargs)
+            # Tokenize the dataset if needed
+            if not is_processed:
+                if isinstance(dataset, Dataset):  # `IterableDataset.map` does not support `desc`
+                    map_kwargs["desc"] = f"Tokenizing {dataset_name} dataset"
+
+                def tokenize(example, processing_class, dataset_text_field):
+                    return processing_class(example[dataset_text_field])
+
+                dataset = dataset.map(
+                    tokenize,
+                    fn_kwargs={"processing_class": processing_class, "dataset_text_field": args.dataset_text_field},
+                    **map_kwargs,
+                )
 
             # Pack or truncate
             if packing:
@@ -421,10 +449,18 @@ class SFTTrainer(Trainer):
                     pack_examples, batched=True, fn_kwargs={"seq_length": args.max_seq_length}, **map_kwargs
                 )
             elif args.max_seq_length is not None:
                 if isinstance(dataset, Dataset):  # `IterableDataset.map` does not support `desc`
                     map_kwargs["desc"] = f"Truncating {dataset_name} dataset"
+
+                def truncate(example, max_seq_length):
+                    return {key: example[key][:max_seq_length] for key in ["input_ids", "attention_mask"]}
+
                 dataset = dataset.map(
-                    lambda ex: {key: ex[key][: args.max_seq_length] for key in ["input_ids", "attention_mask"]},
+                    truncate,
+                    fn_kwargs={"max_seq_length": args.max_seq_length},
                     **map_kwargs,
                 )
 
             # For Liger kernel, ensure only input_ids is present
             if args.use_liger:
                 dataset = dataset.select_columns("input_ids")
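Both the new `tokenize` and `truncate` helpers replace inline lambdas with named functions whose extra arguments travel through `map(..., fn_kwargs=...)`. Presumably this keeps the mapped function picklable and stably hashable for `datasets` caching and multiprocessing. The pattern in isolation looks like this (a sketch with made-up data, not trainer code):

```python
# Sketch: passing extra arguments to datasets.map via fn_kwargs instead of closing over them in a lambda.
from datasets import Dataset

dataset = Dataset.from_dict({"input_ids": [[1, 2, 3, 4, 5]], "attention_mask": [[1, 1, 1, 1, 1]]})


def truncate(example, max_seq_length):
    return {key: example[key][:max_seq_length] for key in ["input_ids", "attention_mask"]}


dataset = dataset.map(truncate, fn_kwargs={"max_seq_length": 3})
print(dataset[0]["input_ids"])  # [1, 2, 3]
```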
@@ -444,14 +480,24 @@ class SFTTrainer(Trainer):
             shift_logits = outputs.logits[..., :-1, :].contiguous()
             shift_labels = inputs["labels"][..., 1:].contiguous()
 
-            # Gather logits and labels from all GPUs first
-            shift_logits = self.accelerator.gather_for_metrics(shift_logits)
-            shift_labels = self.accelerator.gather_for_metrics(shift_labels)
+            # Get predictions
+            predictions = shift_logits.argmax(dim=-1)
 
-            # Then compute accuracy on the gathered tensors
-            if self.accelerator.is_main_process:
-                accuracy = compute_token_accuracy(shift_logits, shift_labels)
-                self._metrics["mean_token_accuracy"].append(accuracy)
+            # Create mask for non-padding tokens (assuming ignore_index is -100)
+            mask = shift_labels != -100
+
+            # Calculate accuracy only on non-padding tokens
+            correct_predictions = (predictions == shift_labels) & mask
+            total_tokens = mask.sum()
+            correct_tokens = correct_predictions.sum()
+
+            # Gather the correct_tokens and total_tokens across all processes
+            correct_tokens = self.accelerator.gather_for_metrics(correct_tokens)
+            total_tokens = self.accelerator.gather_for_metrics(total_tokens)
+
+            # Compute the mean token accuracy and log it
+            accuracy = (correct_tokens.sum() / total_tokens.sum()).item() if total_tokens.sum() > 0 else 0.0
+            self._metrics["mean_token_accuracy"].append(accuracy)
 
         return (loss, outputs) if return_outputs else loss
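The rewritten block above trades gathering full logits and labels across GPUs for gathering two scalar counts per process, which is far cheaper over large vocabularies and keeps every rank's tokens in the average. Stripped of the trainer context, the core computation looks roughly like this (a self-contained sketch; the tensors are made up and no Accelerator is involved):

```python
# Sketch: masked token accuracy from shifted logits/labels, as in the hunk above.
import torch

shift_logits = torch.randn(2, 5, 32)          # (batch, seq, vocab)
shift_labels = torch.randint(0, 32, (2, 5))   # (batch, seq)
shift_labels[:, :2] = -100                    # pretend the first tokens are prompt/padding

predictions = shift_logits.argmax(dim=-1)
mask = shift_labels != -100                   # ignore_index convention from transformers

correct_tokens = ((predictions == shift_labels) & mask).sum()
total_tokens = mask.sum()

# In the trainer, correct_tokens and total_tokens are gathered across processes here
# (accelerator.gather_for_metrics) before taking the ratio.
accuracy = (correct_tokens / total_tokens).item() if total_tokens > 0 else 0.0
print(accuracy)
```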
@@ -1650,27 +1650,6 @@ def flush_left(mask: torch.Tensor, *tensors: torch.Tensor) -> tuple[torch.Tensor
     return mask, *tensors
 
 
-def compute_token_accuracy(logits: torch.Tensor, labels: torch.Tensor, ignore_index: int = -100) -> float:
-    """
-    Compute the mean token accuracy.
-    """
-    # Get predictions
-    predictions = logits.argmax(dim=-1)
-
-    # Create mask for non-padding tokens (assuming pad_token_id is ignore_index)
-    mask = labels != ignore_index
-
-    # Calculate accuracy only on non-padding tokens
-    correct_predictions = (predictions == labels) & mask
-    total_tokens = mask.sum()
-    correct_tokens = correct_predictions.sum()
-
-    # Calculate accuracy
-    accuracy = correct_tokens.item() / total_tokens.item() if total_tokens > 0 else 0.0
-
-    return accuracy
-
-
 def selective_log_softmax(logits, index):
     """
     A memory-efficient implementation of the common `log_softmax -> gather` operation.