Compare commits

...

10 Commits

Author SHA1 Message Date
fc2b041b58 Release: v0.15.2 2025-02-25 22:27:14 +00:00
d4098d1dec 📌 Pin liger-kernel and vLLM (#2952)
* pin liger-kernel

* style
2025-02-25 22:25:18 +00:00
e437710883 🐯 Fix LigerKernel for SFTTrainer (#2940)
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-02-25 22:22:51 +00:00
51d383efca ♻️ Fix caching in SFT (#2945) 2025-02-25 22:22:19 +00:00
38c33c547f Release: v0.15.1 2025-02-18 14:35:49 +00:00
8aa13d809c 🪂 Don't gather logits in SFT to avoid hanging (#2890)
* Don't gather logits

* Remove unused function and test
2025-02-18 14:31:34 +00:00
99c78f123f 🧶 [GRPO][vLLM + LoRA] Move unmerge of PEFT model after weight loading (#2873) 2025-02-18 10:25:08 +00:00
2ccca1d0aa 🍟 [SFT] Handles the dataset if it has been preprocessed (#2863)
* return dataset if it's preprocessed

* add is_processed flag variable

* add test

* move test_sft_trainer_directly_with_pretokenized_data to Tester2

* Update sft_trainer.py

* no need for padding and truncation

* minor reorganization

* Update trl/trainer/sft_trainer.py

* let the collator pad

* style

* fix tests

---------

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
2025-02-18 10:23:54 +00:00
7596db89da [SFT] fix check for AutoLigerKernelForCausalLM (#2874)
* fix check for AutoLigerKernelForCausalLM

* fix case where AutoLigerKernelForCausalLM is not defined

* update min liger version

* formatting

* fix win CI
2025-02-17 18:36:43 +00:00
def8e48dce 💬 Add maybe_convert_to_chatml map for conversational datasets in SFT (#2862)
* add back get_formatting_func_from_dataset

* maybe_convert_to_chatml

* maybe_convert_to_chatml before maybe_apply_chat_template map

* remove comment

* test

* desc

* style

* Update trl/data_utils.py

---------

Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-02-17 18:36:29 +00:00
11 changed files with 259 additions and 120 deletions

View File

@ -12,6 +12,10 @@
[[autodoc]] maybe_apply_chat_template
## maybe_convert_to_chatml
[[autodoc]] maybe_convert_to_chatml
## extract_prompt
[[autodoc]] extract_prompt

View File

@ -71,7 +71,7 @@ To create the package for PyPI.
from setuptools import find_packages, setup
__version__ = "0.15.0" # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
__version__ = "0.15.2" # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
REQUIRED_PKGS = [
"accelerate>=0.34.0",
@ -85,13 +85,16 @@ EXTRAS = {
"diffusers": ["diffusers>=0.18.0"],
"judges": ["openai>=1.23.2", "llm-blender>=0.0.2"],
# liger-kernel depends on triton, which is only available on Linux https://github.com/triton-lang/triton#compatibility
"liger": ["liger-kernel>=0.4.0; sys_platform != 'win32'"],
# can be set to >=0.5.3 when https://github.com/linkedin/Liger-Kernel/issues/586 is fixed
"liger": ["liger-kernel==0.5.3; sys_platform != 'win32'"],
"mergekit": ["mergekit>=0.0.5.1"],
"peft": ["peft>=0.8.0"],
"quantization": ["bitsandbytes"],
"scikit": ["scikit-learn"],
"test": ["parameterized", "pytest-cov", "pytest-rerunfailures", "pytest-xdist", "pytest"],
"vllm": ["vllm>=0.7.1; sys_platform != 'win32'"], # vllm is not available on Windows
# vllm is not available on Windows
# vllm 0.7.3 causes hanging while gathering. temporary pinning the version until the issue is resolved
"vllm": ["vllm==0.7.2; sys_platform != 'win32'"],
"vlm": ["Pillow"],
}
EXTRAS["dev"] = []

View File

@ -24,6 +24,7 @@ from trl.data_utils import (
extract_prompt,
is_conversational,
maybe_apply_chat_template,
maybe_convert_to_chatml,
maybe_extract_prompt,
maybe_unpair_preference_dataset,
pack_examples,
@ -435,6 +436,51 @@ class TestPackExamples(unittest.TestCase):
self.assertEqual(dataset.to_dict(), expected_output)
class TestMaybeConvertToChatML(unittest.TestCase):
def test_with_conversations_key(self):
# Particular case where the key is "conversations": we rename it to "messages"
example = {
"conversations": [
{"from": "user", "value": "What color is the sky?"},
{"from": "assistant", "value": "It is blue."},
]
}
expected_output = {
"messages": [
{"role": "user", "content": "What color is the sky?"},
{"role": "assistant", "content": "It is blue."},
]
}
self.assertEqual(maybe_convert_to_chatml(example), expected_output)
def test_without_conversations_key(self):
# Same as before, but we don't rename the keys
example = {
"prompt": [{"from": "user", "value": "What color is the sky?"}],
"completion": [{"from": "assistant", "value": "It is blue."}],
}
expected_output = {
"prompt": [{"role": "user", "content": "What color is the sky?"}],
"completion": [{"role": "assistant", "content": "It is blue."}],
}
self.assertEqual(maybe_convert_to_chatml(example), expected_output)
def test_not_conversional(self):
# When not needed, the example should remain unchanged
example = {"text": "The sky is blue."}
self.assertEqual(maybe_convert_to_chatml(example), example)
def test_already_chatml(self):
# When the example is already in ChatML format, it should remain unchanged
example = {
"messages": [
{"role": "user", "content": "What color is the sky?"},
{"role": "assistant", "content": "It is blue."},
]
}
self.assertEqual(maybe_convert_to_chatml(example), example)
# Run the tests
if __name__ == "__main__":
unittest.main()

View File

@ -288,7 +288,7 @@ class SFTTrainerTester(unittest.TestCase):
self.assertIn("model.safetensors", os.listdir(tmp_dir + "/checkpoint-2"))
def test_sft_trainer_with_pretokenzied_data_packing(self):
def test_sft_trainer_with_pretokenized_data_packing(self):
with tempfile.TemporaryDirectory() as tmp_dir:
training_args = SFTConfig(
output_dir=tmp_dir,
@ -1370,3 +1370,65 @@ class SFTTrainerTester2(unittest.TestCase):
"base_layer" not in n
): # We expect the peft parameters to be different (except for the base layer)
self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")
def test_train_with_non_chatml_conversational_data(self):
model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
model = AutoModelForCausalLM.from_pretrained(model_id)
dataset = load_dataset("trl-internal-testing/zen", "conversational_language_modeling", split="train")
# Rename role/content to from/value to ensure SFT works with non-chatML conversational data
def rename_fields(example: list[dict]):
return {"conversations": [{"from": m["role"], "value": m["content"]} for m in example["messages"]]}
dataset = dataset.map(rename_fields, remove_columns="messages")
with tempfile.TemporaryDirectory() as tmp_dir:
# Initialize the trainer
training_args = SFTConfig(output_dir=tmp_dir, report_to="none")
trainer = SFTTrainer(args=training_args, model=model, train_dataset=dataset)
# Save the initial parameters to compare them later
previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
# Train the model
trainer.train()
# Check that the training loss is not None
self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
# Check the params have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")
def test_sft_trainer_with_pretokenized_data(self):
# Get the model and dataset
model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
model = AutoModelForCausalLM.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)
dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train")
def tokenize_example(example):
return tokenizer(example["text"])
# Apply tokenization
tokenized_dataset = dataset.map(tokenize_example, remove_columns=["text"])
with tempfile.TemporaryDirectory() as tmp_dir:
# Initialize the trainer
training_args = SFTConfig(output_dir=tmp_dir, report_to="none")
trainer = SFTTrainer(args=training_args, model=model, train_dataset=tokenized_dataset)
# Save the initial parameters to compare them later
previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
# Train the model
trainer.train()
# Check that the training loss is not None
self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
# Check the params have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")

View File

@ -27,7 +27,6 @@ from trl.trainer import compute_accuracy
from trl.trainer.utils import (
DataCollatorForChatML,
batch_generation,
compute_token_accuracy,
decode_and_strip_padding,
flush_left,
generate_model_card,
@ -456,60 +455,6 @@ class TestFlushLeft(unittest.TestCase):
self.assertTrue(torch.equal(new_mask, expected_mask))
class TestComputeTokenAccuracy(unittest.TestCase):
def test_basic_accuracy(self):
# Test basic accuracy computation
logits = torch.tensor([[[0.9, 0.1], [0.8, 0.2]], [[0.3, 0.7], [0.6, 0.4]]]) # Shape: [2, 2, 2]
labels = torch.tensor([[1, 0], [1, 0]]) # Shape: [2, 2]
accuracy = compute_token_accuracy(logits, labels)
self.assertAlmostEqual(accuracy, 0.75) # 3 correct out of 4 tokens
def test_with_ignore_index(self):
# Test accuracy computation with ignored tokens
logits = torch.tensor([[[0.9, 0.1], [0.8, 0.2]], [[0.3, 0.7], [0.6, 0.4]]])
labels = torch.tensor([[1, -100], [1, 0]]) # -100 is ignored
accuracy = compute_token_accuracy(logits, labels, ignore_index=-100)
self.assertAlmostEqual(accuracy, 2 / 3) # 2 correct out of 3 non-ignored tokens
def test_all_ignored(self):
# Test case where all tokens are ignored
logits = torch.tensor([[[0.1, 0.9], [0.8, 0.2]]])
labels = torch.tensor([[-100, -100]])
accuracy = compute_token_accuracy(logits, labels)
self.assertEqual(accuracy, 0.0) # No valid tokens to compute accuracy
def test_perfect_accuracy(self):
# Test case with 100% accuracy
logits = torch.tensor([[[0.1, 0.9], [0.8, 0.2]]])
labels = torch.tensor([[1, 0]])
accuracy = compute_token_accuracy(logits, labels)
self.assertEqual(accuracy, 1.0) # All predictions correct
def test_zero_accuracy(self):
# Test case with 0% accuracy
logits = torch.tensor([[[0.1, 0.9], [0.8, 0.2]]])
labels = torch.tensor([[0, 1]])
accuracy = compute_token_accuracy(logits, labels)
self.assertEqual(accuracy, 0.0) # All predictions wrong
def test_batch_accuracy(self):
# Test accuracy computation across multiple batches
logits = torch.tensor(
[
[[0.9, 0.1], [0.8, 0.2], [0.3, 0.7]], # Batch 1
[[0.2, 0.8], [0.7, 0.3], [0.6, 0.4]], # Batch 2
]
)
labels = torch.tensor(
[
[1, 0, 1], # Batch 1
[1, 0, -100], # Batch 2 (last token ignored)
]
)
accuracy = compute_token_accuracy(logits, labels)
self.assertAlmostEqual(accuracy, 0.8)
class TestSelectiveLogSoftmax(unittest.TestCase):
@parameterized.expand([(torch.float64,), (torch.float32,), (torch.float16,), (torch.bfloat16,)])
def test_selective_log_softmax(self, dtype):

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
__version__ = "0.15.0"
__version__ = "0.15.2"
from typing import TYPE_CHECKING
@ -26,6 +26,7 @@ _import_structure = {
"extract_prompt",
"is_conversational",
"maybe_apply_chat_template",
"maybe_convert_to_chatml",
"maybe_extract_prompt",
"maybe_unpair_preference_dataset",
"pack_examples",
@ -101,7 +102,7 @@ _import_structure = {
"XPOTrainer",
],
"trainer.callbacks": ["MergeModelCallback", "RichProgressCallback", "SyncRefModelCallback"],
"trainer.utils": ["get_kbit_device_map", "get_peft_config", "get_quantization_config", "compute_token_accuracy"],
"trainer.utils": ["get_kbit_device_map", "get_peft_config", "get_quantization_config"],
}
try:
@ -126,6 +127,7 @@ if TYPE_CHECKING:
extract_prompt,
is_conversational,
maybe_apply_chat_template,
maybe_convert_to_chatml,
maybe_extract_prompt,
maybe_unpair_preference_dataset,
pack_examples,
@ -202,7 +204,7 @@ if TYPE_CHECKING:
XPOTrainer,
)
from .trainer.callbacks import RichProgressCallback, SyncRefModelCallback
from .trainer.utils import compute_token_accuracy, get_kbit_device_map, get_peft_config, get_quantization_config
from .trainer.utils import get_kbit_device_map, get_peft_config, get_quantization_config
try:
if not is_diffusers_available():

View File

@ -31,7 +31,8 @@ def is_conversational(example: dict[str, Any]) -> bool:
dataset type.
Returns:
`bool`: `True` if the data is in a conversational format, `False` otherwise.
`bool`:
`True` if the data is in a conversational format, `False` otherwise.
Examples:
@ -185,20 +186,21 @@ def maybe_apply_chat_template(
For keys `"messages"`, `"prompt"`, `"chosen"`, `"rejected"`, and `"completion"`, the values are lists of
messages, where each message is a dictionary with keys `"role"` and `"content"`.
tokenizer (`PreTrainedTokenizer`):
The tokenizer to apply the chat template with.
Tokenizer to apply the chat template with.
tools (`list[Union[dict, Callable]]` or `None`, *optional*, defaults to `None`):
A list of tools (callable functions) that will be accessible to the model.
If the template does not support function calling, this argument will have no effect
Returns:
`dict[str, str]`: The formatted example with the chat template applied.
`dict[str, str]`:
Formatted example with the chat template applied.
Notes:
- This function does not alter the keys, except for Language modeling dataset, where `"messages"` is replaced by
`"text"`.
- This function does not alter the keys, except for Language modeling dataset, where `"messages"` is replaced
by `"text"`.
- In case of prompt-only data, if the last role is `"user"`, the generation prompt is added to the prompt. Else,
if the last role is `"assistant"`, the final message is continued.
- In case of prompt-only data, if the last role is `"user"`, the generation prompt is added to the prompt.
Else, if the last role is `"assistant"`, the final message is continued.
Example:
@ -462,3 +464,52 @@ def pack_examples(examples: dict[str, list[list]], seq_length: int) -> dict[str,
# Split the values into chunks of size seq_length
examples = {k: [v[i : i + seq_length] for i in range(0, len(v), seq_length)] for k, v in examples.items()}
return examples
def maybe_convert_to_chatml(example: dict[str, list]) -> dict[str, list]:
"""
Convert a conversational dataset with fields `from` and `value` to ChatML format.
This function modifies conversational data to align with OpenAI's ChatML format:
- Replaces the key `"from"` with `"role"` in message dictionaries.
- Replaces the key `"value"` with `"content"` in message dictionaries.
- Renames `"conversations"` to `"messages"` for consistency with ChatML.
Args:
example (`dict[str, list]`):
A single data entry containing a list of messages.
Returns:
`dict[str, list]`:
Example reformatted to ChatML style.
Example:
```python
>>> from trl import maybe_convert_to_chatml
>>> example = {
... "conversations": [
... {"from": "user", "value": "What color is the sky?"},
... {"from": "assistant", "value": "It is blue."}
... ]
... }
>>> maybe_convert_to_chatml(example)
{'messages': [{'role': 'user', 'content': 'What color is the sky?'},
{'role': 'assistant', 'content': 'It is blue.'}]}
```
"""
# List of possible keys containing message lists
for key in ["prompt", "completion", "chosen", "rejected", "messages", "conversations"]:
if key in example and isinstance(example[key], list):
messages = example[key]
for message in messages:
if isinstance(message, dict):
if "from" in message:
message["role"] = message.pop("from")
if "value" in message:
message["content"] = message.pop("value")
# Rename "conversations" to "messages"
if "conversations" in example:
example["messages"] = example.pop("conversations")
return example

View File

@ -76,7 +76,6 @@ _import_structure = {
"disable_dropout_in_model",
"empty_cache",
"peft_module_casting_to_bf16",
"compute_token_accuracy",
],
"xpo_config": ["XPOConfig"],
"xpo_trainer": ["XPOTrainer"],
@ -145,7 +144,6 @@ if TYPE_CHECKING:
DataCollatorForCompletionOnlyLM,
RunningMoments,
compute_accuracy,
compute_token_accuracy,
disable_dropout_in_model,
empty_cache,
peft_module_casting_to_bf16,

View File

@ -495,7 +495,6 @@ class GRPOTrainer(Trainer):
if is_peft_model(unwrapped_model):
unwrapped_model.merge_adapter()
state_dict = unwrapped_model.state_dict()
unwrapped_model.unmerge_adapter()
# Remove base_model and base_layer prefixes
state_dict = {
k.removeprefix("base_model.model.").replace(".base_layer", ""): v for k, v in state_dict.items()
@ -510,9 +509,13 @@ class GRPOTrainer(Trainer):
}
else:
state_dict = unwrapped_model.state_dict()
if self.accelerator.is_main_process:
llm_model = self.llm.llm_engine.model_executor.driver_worker.model_runner.model
llm_model.load_weights(state_dict.items())
if self.accelerator.is_main_process:
llm_model = self.llm.llm_engine.model_executor.driver_worker.model_runner.model
llm_model.load_weights(state_dict.items())
# Unmerge the adapter to restore the model to its original state.
# This must be done after loading weights to ensure they correspond to the merged state.
if is_peft_model(unwrapped_model):
unwrapped_model.unmerge_adapter()
def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[str, Union[torch.Tensor, Any]]:
device = self.accelerator.device

View File

@ -43,15 +43,9 @@ from transformers.trainer_utils import EvalPrediction
from transformers.utils import is_liger_kernel_available, is_peft_available
from transformers.utils.deprecation import deprecate_kwarg
from ..data_utils import is_conversational, maybe_apply_chat_template, pack_examples
from ..data_utils import is_conversational, maybe_apply_chat_template, maybe_convert_to_chatml, pack_examples
from .sft_config import SFTConfig
from .utils import (
ConstantLengthDataset,
compute_token_accuracy,
generate_model_card,
get_comet_experiment_url,
peft_module_casting_to_bf16,
)
from .utils import ConstantLengthDataset, generate_model_card, get_comet_experiment_url, peft_module_casting_to_bf16
if is_peft_available():
@ -107,6 +101,8 @@ class SFTTrainer(Trainer):
- [Standard](dataset_formats#standard): Each sample contains plain text.
- [Conversational](dataset_formats#conversational): Each sample contains structured messages (e.g., role
and content).
The trainer also supports processed datasets (tokenized) as long as they contain an `input_ids` field.
eval_dataset ([`~datasets.Dataset`], [`~datasets.IterableDataset`] or `dict[str, Union[Dataset, IterableDataset]]`):
Dataset to use for evaluation. It must meet the same requirements as `train_dataset`.
processing_class ([`~transformers.PreTrainedTokenizerBase`], *optional*, defaults to `None`):
@ -186,7 +182,7 @@ class SFTTrainer(Trainer):
if peft_config is not None:
model = self._prepare_peft_model(model, peft_config, args)
# 3. Handle the tokenizer
# Handle the tokenizer
if processing_class is None:
processing_class = AutoTokenizer.from_pretrained(model.config._name_or_path)
if processing_class.pad_token is None:
@ -275,8 +271,10 @@ class SFTTrainer(Trainer):
if args.use_liger:
if not is_liger_kernel_available():
raise ImportError("Please install Liger-kernel for use_liger=True")
return AutoLigerKernelForCausalLM.from_pretrained(model_path, **model_init_kwargs)
return AutoModelForCausalLM.from_pretrained(model_path, **model_init_kwargs)
model = AutoLigerKernelForCausalLM.from_pretrained(model_path, **model_init_kwargs)
else:
model = AutoModelForCausalLM.from_pretrained(model_path, **model_init_kwargs)
return model
def _prepare_peft_model(self, model: PreTrainedModel, peft_config: Any, args: SFTConfig) -> PreTrainedModel:
"""Prepares a model for PEFT training."""
@ -368,6 +366,10 @@ class SFTTrainer(Trainer):
if isinstance(dataset, ConstantLengthDataset):
return dataset
# If the dataset is already preprocessed (tokenized), skip the processing steps.
column_names = list(next(iter(dataset)).keys())
is_processed = "input_ids" in column_names
# Build the kwargs for the `map` function
map_kwargs = {}
if isinstance(dataset, Dataset): # IterableDataset does not support num_proc
@ -375,7 +377,15 @@ class SFTTrainer(Trainer):
with PartialState().local_main_process_first():
# Apply the formatting function if any
if formatting_func is not None:
if formatting_func is not None and is_processed:
warnings.warn(
"You passed a dataset that is already processed (contains an `input_ids` field) together with a "
"formatting function. Therefore `formatting_func` will be ignored. Either remove the "
"`formatting_func` or pass a dataset that is not already processed.",
UserWarning,
)
if formatting_func is not None and not is_processed:
if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc`
map_kwargs["desc"] = f"Applying formatting function to {dataset_name} dataset"
@ -395,6 +405,15 @@ class SFTTrainer(Trainer):
dataset = dataset.map(concat_prompt_completion, remove_columns=["prompt", "completion"])
# Convert the dataset to ChatML if needed
if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc`
map_kwargs["desc"] = f"Converting {dataset_name} dataset to ChatML"
dataset = dataset.map(
maybe_convert_to_chatml,
remove_columns="conversations" if "conversations" in dataset.column_names else None,
**map_kwargs,
)
# Apply the chat template if needed
if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc`
map_kwargs["desc"] = f"Applying chat template to {dataset_name} dataset"
@ -405,10 +424,19 @@ class SFTTrainer(Trainer):
**map_kwargs,
)
# Tokenize the dataset
if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc`
map_kwargs["desc"] = f"Tokenizing {dataset_name} dataset"
dataset = dataset.map(lambda ex: processing_class(ex[args.dataset_text_field]), **map_kwargs)
# Tokenize the dataset if needed
if not is_processed:
if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc`
map_kwargs["desc"] = f"Tokenizing {dataset_name} dataset"
def tokenize(example, processing_class, dataset_text_field):
return processing_class(example[dataset_text_field])
dataset = dataset.map(
tokenize,
fn_kwargs={"processing_class": processing_class, "dataset_text_field": args.dataset_text_field},
**map_kwargs,
)
# Pack or truncate
if packing:
@ -421,10 +449,18 @@ class SFTTrainer(Trainer):
pack_examples, batched=True, fn_kwargs={"seq_length": args.max_seq_length}, **map_kwargs
)
elif args.max_seq_length is not None:
if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc`
map_kwargs["desc"] = f"Truncating {dataset_name} dataset"
def truncate(example, max_seq_length):
return {key: example[key][:max_seq_length] for key in ["input_ids", "attention_mask"]}
dataset = dataset.map(
lambda ex: {key: ex[key][: args.max_seq_length] for key in ["input_ids", "attention_mask"]},
truncate,
fn_kwargs={"max_seq_length": args.max_seq_length},
**map_kwargs,
)
# For Liger kernel, ensure only input_ids is present
if args.use_liger:
dataset = dataset.select_columns("input_ids")
@ -444,14 +480,24 @@ class SFTTrainer(Trainer):
shift_logits = outputs.logits[..., :-1, :].contiguous()
shift_labels = inputs["labels"][..., 1:].contiguous()
# Gather logits and labels from all GPUs first
shift_logits = self.accelerator.gather_for_metrics(shift_logits)
shift_labels = self.accelerator.gather_for_metrics(shift_labels)
# Get predictions
predictions = shift_logits.argmax(dim=-1)
# Then compute accuracy on the gathered tensors
if self.accelerator.is_main_process:
accuracy = compute_token_accuracy(shift_logits, shift_labels)
self._metrics["mean_token_accuracy"].append(accuracy)
# Create mask for non-padding tokens (assuming ignore_index is -100)
mask = shift_labels != -100
# Calculate accuracy only on non-padding tokens
correct_predictions = (predictions == shift_labels) & mask
total_tokens = mask.sum()
correct_tokens = correct_predictions.sum()
# Gather the correct_tokens and total_tokens across all processes
correct_tokens = self.accelerator.gather_for_metrics(correct_tokens)
total_tokens = self.accelerator.gather_for_metrics(total_tokens)
# Compute the mean token accuracy and log it
accuracy = (correct_tokens.sum() / total_tokens.sum()).item() if total_tokens.sum() > 0 else 0.0
self._metrics["mean_token_accuracy"].append(accuracy)
return (loss, outputs) if return_outputs else loss

View File

@ -1650,27 +1650,6 @@ def flush_left(mask: torch.Tensor, *tensors: torch.Tensor) -> tuple[torch.Tensor
return mask, *tensors
def compute_token_accuracy(logits: torch.Tensor, labels: torch.Tensor, ignore_index: int = -100) -> float:
"""
Compute the mean token accuracy.
"""
# Get predictions
predictions = logits.argmax(dim=-1)
# Create mask for non-padding tokens (assuming pad_token_id is ignore_index)
mask = labels != ignore_index
# Calculate accuracy only on non-padding tokens
correct_predictions = (predictions == labels) & mask
total_tokens = mask.sum()
correct_tokens = correct_predictions.sum()
# Calculate accuracy
accuracy = correct_tokens.item() / total_tokens.item() if total_tokens > 0 else 0.0
return accuracy
def selective_log_softmax(logits, index):
"""
A memory-efficient implementation of the common `log_softmax -> gather` operation.