Compare commits

...

25 Commits

Author SHA1 Message Date
83f75601a8 Merge branch 'main' into liger-orpo 2025-08-20 10:20:51 -07:00
f0eb1af293 Merge branch 'main' into liger-orpo 2025-08-20 17:18:32 +00:00
f2ed765da1 undo change 2025-02-20 12:21:04 +00:00
9c317a5afb undo change 2025-02-20 12:19:08 +00:00
851ff26675 use fields 2025-02-20 12:13:03 +00:00
e3b4731054 Merge branch 'main' into liger-orpo 2025-02-20 12:50:11 +01:00
ac2328be1b Update setup.py 2025-02-13 14:11:56 +01:00
f6ffbf6bb1 fix enc-dec 2025-01-03 17:02:36 +01:00
4861e8f4ff Merge branch 'main' into liger-orpo 2024-12-29 15:45:23 +01:00
5ee37a6134 call with nll_target 2024-12-29 14:55:11 +01:00
e1918b77b9 add back the orpo nll labels 2024-12-28 16:41:30 +01:00
5fae1b25e9 Merge branch 'main' into liger-orpo 2024-12-23 16:10:39 +01:00
5c6744ff47 call orpo_loss_fn with shifted inputs 2024-12-19 20:42:01 +01:00
f4979b0f15 pass is_enc_dec 2024-12-19 12:02:42 +01:00
568e21a27f add back missing line 2024-12-19 11:38:36 +01:00
6f7918fbdb Merge branch 'main' into liger-orpo 2024-12-19 11:35:10 +01:00
aa3c3b7eda Merge branch 'main' into liger-orpo 2024-12-19 11:09:58 +01:00
5776a4e354 make it a bit more robust 2024-12-17 13:03:02 +01:00
afaf5a86d0 use get_decoder() 2024-12-17 11:56:37 +01:00
b3f3270377 skip the lm_head when use_liger_loss is true 2024-12-17 11:09:38 +01:00
220f7541d6 make import more readable 2024-12-15 17:49:49 +01:00
c383bf6116 update liger version 2024-12-15 17:46:45 +01:00
7682e31926 passing self.args.use_liger_loss without liger installed should raised an error 2024-12-15 16:57:36 +01:00
44aa20c56e Update tests/test_orpo_trainer.py
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2024-12-15 16:52:53 +01:00
b480fff127 add native liger-kernl orpo loss 2024-12-15 13:54:00 +01:00
3 changed files with 204 additions and 77 deletions

tests/test_orpo_trainer.py

@@ -17,7 +17,7 @@ import torch
from datasets import load_dataset
from parameterized import parameterized
from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer
from transformers.testing_utils import require_peft
from transformers.testing_utils import require_liger_kernel, require_peft
from trl import ORPOConfig, ORPOTrainer
from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE
@@ -179,3 +179,36 @@ class ORPOTrainerTester(TrlTestCase):
        trainer.train()

        self.assertEqual(trainer.state.log_history[-2]["eval_test"], 0.0)

    @require_liger_kernel
    def test_orpo_trainer_with_liger(self):
        """Test ORPO trainer with Liger loss enabled."""
        training_args = ORPOConfig(
            output_dir=self.tmp_dir,
            report_to="none",
            use_liger_loss=True,  # Enable Liger loss
        )
        dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_preference")

        trainer = ORPOTrainer(
            model=self.model,
            args=training_args,
            processing_class=self.tokenizer,
            train_dataset=dummy_dataset["train"],
            eval_dataset=dummy_dataset["test"],
        )

        previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}

        trainer.train()

        self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])

        # check the params have changed
        for n, param in previous_trainable_params.items():
            new_param = trainer.model.get_parameter(n)
            # check the params have changed - ignore 0 biases
            if param.sum() != 0:
                self.assertFalse(torch.equal(param, new_param))
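Outside the test suite, the new flag would be used the same way: enable `use_liger_loss` on `ORPOConfig` and train as usual. The sketch below is illustrative only, assuming `liger-kernel` is installed; the model checkpoint name is a placeholder, and only the dataset name is taken from the diff above.

# Minimal usage sketch (not part of the diff). Assumes `pip install liger-kernel`
# and a preference dataset with prompt/chosen/rejected columns; names are placeholders.
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import ORPOConfig, ORPOTrainer

model = AutoModelForCausalLM.from_pretrained("your-org/your-causal-lm")
tokenizer = AutoTokenizer.from_pretrained("your-org/your-causal-lm")
dataset = load_dataset("trl-internal-testing/zen", "standard_preference")

training_args = ORPOConfig(
    output_dir="orpo-liger",
    use_liger_loss=True,  # the flag introduced by this branch
    report_to="none",
)
trainer = ORPOTrainer(
    model=model,
    args=training_args,
    processing_class=tokenizer,
    train_dataset=dataset["train"],
)
trainer.train()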

trl/trainer/orpo_config.py

@@ -63,6 +63,11 @@ class ORPOConfig(TrainingArguments):
            string.
        dataset_num_proc (`int` or `None`, *optional*, defaults to `None`):
            Number of processes to use for processing the dataset.
        use_liger_loss (`bool`, *optional*, defaults to `False`):
            Whether to use Liger loss.
        base_model_attribute_name (`str`, *optional*, defaults to `"model"`):
            Name of the attribute in the model that contains the base model. This is used to get the base model
            from the model when it does not have a `get_decoder` method and `use_liger_loss` is `True`.
    """

    _VALID_DICT_FIELDS = TrainingArguments._VALID_DICT_FIELDS + ["model_init_kwargs"]
@@ -162,6 +167,18 @@
        default=None,
        metadata={"help": "Number of processes to use for processing the dataset."},
    )
    use_liger_loss: bool = field(
        default=False,
        metadata={"help": "Whether to use Liger loss."},
    )
    base_model_attribute_name: str = field(
        default="model",
        metadata={
            "help": "Name of the attribute in the model that contains the base model. This is used to get the "
            "base model from the model when it does not have a `get_decoder` method and `use_liger_loss` is "
            "`True`."
        },
    )

    def __post_init__(self):
        self.bf16 = not (self.fp16) if self.bf16 is None else self.bf16
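The `base_model_attribute_name` field only matters on the Liger path: the trainer skips the LM head and feeds the backbone's last hidden state to the fused loss, falling back to this attribute when the wrapped model exposes no `get_decoder` method. Roughly, mirroring the trainer change below (the helper name here is made up for illustration):

# Hypothetical helper showing how the two new fields interact; the real logic
# lives inline in ORPOTrainer.concatenated_forward (see the trainer diff below).
def resolve_base_model(model, args):
    if not args.use_liger_loss:
        return model  # standard path keeps the full model, logits and all
    if hasattr(model, "get_decoder"):
        return model.get_decoder()
    # e.g. LlamaForCausalLM.model; configurable for architectures that differ
    return getattr(model, args.base_model_attribute_name)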

trl/trainer/orpo_trainer.py

@@ -45,7 +45,7 @@ from transformers import (
)
from transformers.trainer_callback import TrainerCallback
from transformers.trainer_utils import EvalLoopOutput
from transformers.utils import is_peft_available, is_torch_fx_proxy
from transformers.utils import is_liger_kernel_available, is_peft_available, is_torch_fx_proxy
from ..data_utils import maybe_apply_chat_template, maybe_extract_prompt
from .orpo_config import ORPOConfig
@@ -66,13 +66,15 @@ from .utils import (

if is_peft_available():
    from peft import PeftModel, get_peft_model, prepare_model_for_kbit_training

if is_wandb_available():
    import wandb

if is_torch_xla_available():
    import torch_xla.core.xla_model as xm

if is_liger_kernel_available():
    from liger_kernel.chunked_loss import LigerFusedLinearORPOLoss

logger = logging.get_logger(__name__)
@@ -356,6 +358,15 @@ class ORPOTrainer(Trainer):
                "Your `Trainer` does not have an `accelerator` object. Consider upgrading `transformers`."
            )

        # Import Liger loss if enabled
        if self.args.use_liger_loss:
            if not is_liger_kernel_available():
                raise ValueError(
                    "You set `use_liger_loss=True` but the liger kernel is not available. "
                    "Please install liger-kernel first: `pip install liger-kernel`"
                )
            self.orpo_loss_fn = LigerFusedLinearORPOLoss(ignore_index=self.label_pad_token_id, beta=self.beta)

    def build_tokenized_answer(self, prompt, answer):
        """
        Llama tokenizer does satisfy `enc(a + b) = enc(a) + enc(b)`. It does ensure `enc(a + b) = enc(a) + enc(a +
@@ -730,59 +741,112 @@
        if self.aux_loss_enabled:
            model_kwargs["output_router_logits"] = True

        outputs = model(
            concatenated_batch["concatenated_input_ids"],
            attention_mask=concatenated_batch["concatenated_attention_mask"],
            use_cache=False,
            **model_kwargs,
        )
        all_logits = outputs.logits

        def cross_entropy_loss(logits, labels):
            if not self.is_encoder_decoder:
                # Shift so that tokens < n predict n
                logits = logits[..., :-1, :].contiguous()
                labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = nn.CrossEntropyLoss()
            logits = logits.view(-1, logits.shape[-1])
            labels = labels.view(-1)
            # Enable model parallelism
            labels = labels.to(logits.device)
            loss = loss_fct(logits, labels)
            return loss

        # orpo nll target is with respect to the concatenated prompt + completion labels
        if self.is_encoder_decoder:
            labels = concatenated_batch["concatenated_labels"].clone()
        else:
            labels = concatenated_batch["concatenated_input_ids"].clone()
            attention_mask = concatenated_batch["concatenated_attention_mask"]
            labels = torch.where(attention_mask == 1, labels, self.label_pad_token_id)

        # orpo chosen nll loss is computed over the full prompt and response
        chosen_nll_loss = cross_entropy_loss(all_logits[:len_chosen], labels[:len_chosen])

        all_logps = self.get_batch_logps(
            all_logits,
            concatenated_batch["concatenated_labels"],
            average_log_prob=True,
            is_encoder_decoder=self.is_encoder_decoder,
            label_pad_token_id=self.label_pad_token_id,
        )
        if self.args.use_liger_loss:
            if self.is_encoder_decoder:
                # 1. Get encoder outputs
                encoder_outputs = model.get_encoder()(
                    concatenated_batch["concatenated_input_ids"],
                    attention_mask=concatenated_batch["concatenated_attention_mask"],
                    return_dict=True,
                )
                # 2. Get decoder outputs
                outputs = model.get_decoder()(
                    input_ids=model_kwargs["decoder_input_ids"],
                    encoder_hidden_states=encoder_outputs.last_hidden_state,
                    use_cache=False,
                )
            else:
                # skip the lm head and get the last hidden state
                if hasattr(model, "get_decoder"):
                    base_model = model.get_decoder()
                else:
                    base_model = getattr(model, self.args.base_model_attribute_name)
                outputs = base_model(
                    concatenated_batch["concatenated_input_ids"],
                    attention_mask=concatenated_batch["concatenated_attention_mask"],
                    use_cache=False,
                    **model_kwargs,
                )
            lm_head = model.get_output_embeddings()
        chosen_logps = all_logps[:len_chosen]
        rejected_logps = all_logps[len_chosen:]
            # return the final loss and aux_outputs tuple
            loss, aux_outputs = self.orpo_loss_fn(
                lm_head.weight,
                outputs.last_hidden_state[:, :-1] if not self.is_encoder_decoder else outputs.last_hidden_state,
                concatenated_batch["concatenated_labels"][:, 1:]
                if not self.is_encoder_decoder
                else concatenated_batch["concatenated_labels"],
                lm_head.bias if hasattr(lm_head, "bias") else None,
                nll_target=labels[:, 1:] if not self.is_encoder_decoder else labels,
            )
        if not self.is_encoder_decoder:
            chosen_logits = all_logits[:len_chosen, :-1, :]
            rejected_logits = all_logits[len_chosen:, :-1, :]
            if self.aux_loss_enabled:
                loss += self.aux_loss_coef * outputs.aux_loss

            return loss, aux_outputs
        else:
            chosen_logits = all_logits[:len_chosen]
            rejected_logits = all_logits[len_chosen:]
            outputs = model(
                concatenated_batch["concatenated_input_ids"],
                attention_mask=concatenated_batch["concatenated_attention_mask"],
                use_cache=False,
                output_hidden_states=False,
                **model_kwargs,
            )
            all_logits = outputs.logits
        if self.aux_loss_enabled:
            return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_nll_loss, outputs.aux_loss)
            def cross_entropy_loss(logits, labels):
                if not self.is_encoder_decoder:
                    # Shift so that tokens < n predict n
                    logits = logits[..., :-1, :].contiguous()
                    labels = labels[..., 1:].contiguous()
                # Flatten the tokens
                loss_fct = nn.CrossEntropyLoss()
                logits = logits.view(-1, logits.shape[-1])
                labels = labels.view(-1)
                # Enable model parallelism
                labels = labels.to(logits.device)
                loss = loss_fct(logits, labels)
                return loss
        return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_nll_loss)

            chosen_nll_loss = cross_entropy_loss(all_logits[:len_chosen], labels[:len_chosen])

            all_logps = self.get_batch_logps(
                all_logits,
                concatenated_batch["concatenated_labels"],
                average_log_prob=True,
                is_encoder_decoder=self.is_encoder_decoder,
                label_pad_token_id=self.label_pad_token_id,
            )

            chosen_logps = all_logps[:len_chosen]
            rejected_logps = all_logps[len_chosen:]

            if not self.is_encoder_decoder:
                chosen_logits = all_logits[:len_chosen, :-1, :]
                rejected_logits = all_logits[len_chosen:, :-1, :]
            else:
                chosen_logits = all_logits[:len_chosen]
                rejected_logits = all_logits[len_chosen:]

            if self.aux_loss_enabled:
                return (
                    chosen_logps,
                    rejected_logps,
                    chosen_logits,
                    rejected_logits,
                    chosen_nll_loss,
                    outputs.aux_loss,
                )

            return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_nll_loss)
    def get_batch_loss_metrics(
        self,
@@ -794,48 +858,61 @@
        metrics = {}
        forward_output = self.concatenated_forward(model, batch)
        (
            policy_chosen_logps,
            policy_rejected_logps,
            policy_chosen_logits,
            policy_rejected_logits,
            policy_nll_loss,
        ) = forward_output[:5]
        if self.aux_loss_enabled:
            aux_loss = forward_output[5]

        if self.args.use_liger_loss:
            # full ORPO loss and aux outputs
            (
                loss,
                (
                    policy_chosen_logps,
                    policy_rejected_logps,
                    policy_chosen_logits,
                    policy_rejected_logits,
                    policy_nll_loss,
                    chosen_rewards,
                    rejected_rewards,
                    log_odds_ratio,
                    log_odds_chosen,
                ),
            ) = forward_output
        else:
            (
                policy_chosen_logps,
                policy_rejected_logps,
                policy_chosen_logits,
                policy_rejected_logits,
                policy_nll_loss,
            ) = forward_output[:5]
            if self.aux_loss_enabled:
                aux_loss = forward_output[5]

        losses, chosen_rewards, rejected_rewards, log_odds_ratio, log_odds_chosen = self.odds_ratio_loss(
            policy_chosen_logps, policy_rejected_logps
        )
        # full ORPO loss
        loss = policy_nll_loss - losses.mean()
            losses, chosen_rewards, rejected_rewards, log_odds_ratio, log_odds_chosen = self.odds_ratio_loss(
                policy_chosen_logps, policy_rejected_logps
            )
            # full ORPO loss
            loss = policy_nll_loss - losses.mean()
        if self.aux_loss_enabled:
            loss += self.aux_loss_coef * aux_loss
        reward_accuracies = (chosen_rewards > rejected_rewards).float()

        prefix = "eval_" if train_eval == "eval" else ""
        metrics[f"{prefix}rewards/chosen"] = self.accelerator.gather_for_metrics(chosen_rewards).mean()
        metrics[f"{prefix}rewards/rejected"] = self.accelerator.gather_for_metrics(rejected_rewards).mean()
        metrics[f"{prefix}rewards/accuracies"] = self.accelerator.gather_for_metrics(reward_accuracies).mean()
        metrics[f"{prefix}rewards/margins"] = self.accelerator.gather_for_metrics(
            chosen_rewards - rejected_rewards
        ).mean()
        metrics[f"{prefix}logps/rejected"] = self.accelerator.gather_for_metrics(policy_rejected_logps).detach().mean()
        metrics[f"{prefix}logps/chosen"] = self.accelerator.gather_for_metrics(policy_chosen_logps).detach().mean()
        metrics[f"{prefix}logits/rejected"] = self.accelerator.gather_for_metrics(
            policy_rejected_logits.detach().mean()
        ).mean()
        metrics[f"{prefix}logits/chosen"] = self.accelerator.gather_for_metrics(
            policy_chosen_logits.detach().mean()
        ).mean()
        metrics[f"{prefix}nll_loss"] = self.accelerator.gather_for_metrics(policy_nll_loss).detach().mean()
        metrics[f"{prefix}log_odds_ratio"] = self.accelerator.gather_for_metrics(log_odds_ratio).detach().mean()
        metrics[f"{prefix}log_odds_chosen"] = self.accelerator.gather_for_metrics(log_odds_chosen).detach().mean()
        metrics[f"{prefix}rewards/chosen"] = chosen_rewards.mean()
        metrics[f"{prefix}rewards/rejected"] = rejected_rewards.mean()
        metrics[f"{prefix}rewards/accuracies"] = reward_accuracies.mean()
        metrics[f"{prefix}rewards/margins"] = (chosen_rewards - rejected_rewards).mean()
        metrics[f"{prefix}logps/rejected"] = policy_rejected_logps.detach().mean()
        metrics[f"{prefix}logps/chosen"] = policy_chosen_logps.detach().mean()
        metrics[f"{prefix}logits/rejected"] = policy_rejected_logits.detach().mean()
        metrics[f"{prefix}logits/chosen"] = policy_chosen_logits.detach().mean()
        metrics[f"{prefix}nll_loss"] = policy_nll_loss.detach().mean()
        metrics[f"{prefix}log_odds_ratio"] = log_odds_ratio
        metrics[f"{prefix}log_odds_chosen"] = log_odds_chosen
        if is_torch_xla_available():
            xm.mark_step()  # needed because .item() calls
        for k, v in metrics.items():
            metrics[k] = v.item()
        if self.aux_loss_enabled:
            loss += self.aux_loss_coef * aux_loss

        return loss, metrics
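Conceptually, the Liger branch computes the same objective as the standard branch: the chosen-sequence NLL plus a `beta`-weighted odds-ratio term over average per-sequence log-probs, with the `lm_head` projection fused into the loss so the full `[batch, seq, vocab]` logits tensor is never materialized. The sketch below is a rough, unfused reference for what the fused call stands in for under those assumptions; it is not the Liger kernel's implementation, and the function name is made up.

import torch
import torch.nn.functional as F

def unfused_orpo_reference(hidden, weight, bias, labels, nll_target, beta, ignore_index=-100):
    """Rough reference only. Shapes follow the trainer above: `hidden` is the backbone's
    last hidden state [2B, T, H] (chosen half first), `labels`/`nll_target` are [2B, T],
    and `weight`/`bias` come from `model.get_output_embeddings()`."""
    logits = hidden @ weight.T
    if bias is not None:
        logits = logits + bias
    logp = F.log_softmax(logits.float(), dim=-1)

    # average log-prob per sequence over non-padding tokens (average_log_prob=True above)
    mask = labels != ignore_index
    per_token = torch.gather(logp, 2, labels.masked_fill(~mask, 0).unsqueeze(-1)).squeeze(-1)
    avg_logp = (per_token * mask).sum(-1) / mask.sum(-1)

    half = hidden.shape[0] // 2
    chosen, rejected = avg_logp[:half], avg_logp[half:]

    # odds-ratio term, as in `odds_ratio_loss`
    log_odds = (chosen - rejected) - (
        torch.log1p(-torch.exp(chosen)) - torch.log1p(-torch.exp(rejected))
    )
    or_term = -F.logsigmoid(log_odds).mean()

    # NLL over the chosen half, against the prompt + completion targets (`nll_target`)
    nll_mask = nll_target[:half] != ignore_index
    nll = F.cross_entropy(logits[:half][nll_mask], nll_target[:half][nll_mask])

    return nll + beta * or_term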