Mirror of https://github.com/huggingface/trl.git (synced 2025-10-20 18:43:52 +08:00)

Compare commits: f6e7c200c0 ... liger-orpo (25 commits)
SHA1
83f75601a8
f0eb1af293
f2ed765da1
9c317a5afb
851ff26675
e3b4731054
ac2328be1b
f6ffbf6bb1
4861e8f4ff
5ee37a6134
e1918b77b9
5fae1b25e9
5c6744ff47
f4979b0f15
568e21a27f
6f7918fbdb
aa3c3b7eda
5776a4e354
afaf5a86d0
b3f3270377
220f7541d6
c383bf6116
7682e31926
44aa20c56e
b480fff127
@@ -17,7 +17,7 @@ import torch
from datasets import load_dataset
from parameterized import parameterized
from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer
from transformers.testing_utils import require_peft
from transformers.testing_utils import require_liger_kernel, require_peft

from trl import ORPOConfig, ORPOTrainer
from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE
@@ -179,3 +179,36 @@ class ORPOTrainerTester(TrlTestCase):
        trainer.train()

        self.assertEqual(trainer.state.log_history[-2]["eval_test"], 0.0)

    @require_liger_kernel
    def test_orpo_trainer_with_liger(self):
        """Test ORPO trainer with Liger loss enabled."""

        training_args = ORPOConfig(
            output_dir=self.tmp_dir,
            report_to="none",
            use_liger_loss=True,  # Enable Liger loss
        )

        dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_preference")

        trainer = ORPOTrainer(
            model=self.model,
            args=training_args,
            processing_class=self.tokenizer,
            train_dataset=dummy_dataset["train"],
            eval_dataset=dummy_dataset["test"],
        )

        previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}

        trainer.train()

        self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])

        # check the params have changed
        for n, param in previous_trainable_params.items():
            new_param = trainer.model.get_parameter(n)
            # check the params have changed - ignore 0 biases
            if param.sum() != 0:
                self.assertFalse(torch.equal(param, new_param))
@@ -63,6 +63,11 @@ class ORPOConfig(TrainingArguments):
            string.
        dataset_num_proc (`int` or `None`, *optional*, defaults to `None`):
            Number of processes to use for processing the dataset.
        use_liger_loss (`bool`, *optional*, defaults to `False`):
            Whether to use Liger loss.
        base_model_attribute_name (`str`, *optional*, defaults to `"model"`):
            Name of the attribute in the model that contains the base model. This is used to get the base model from
            the model when the model does not have a `get_decoder` method in the case when `use_liger_loss` is `True`.
    """

    _VALID_DICT_FIELDS = TrainingArguments._VALID_DICT_FIELDS + ["model_init_kwargs"]

@@ -162,6 +167,18 @@ class ORPOConfig(TrainingArguments):
        default=None,
        metadata={"help": "Number of processes to use for processing the dataset."},
    )
    use_liger_loss: bool = field(
        default=False,
        metadata={"help": "Whether to use Liger loss."},
    )
    base_model_attribute_name: str = field(
        default="model",
        metadata={
            "help": "Name of the attribute in the model that contains the base model. This is used to get the base "
            "model from the model when the model does not have a `get_decoder` method in the case when "
            "`use_liger_loss` is `True`."
        },
    )

    def __post_init__(self):
        self.bf16 = not (self.fp16) if self.bf16 is None else self.bf16
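Editor's note: a minimal, hedged sketch of how these two new options are meant to be used from user code. The checkpoint name is a placeholder and the dataset is the same toy preference set used by the tests in this branch; everything else follows the ORPOTrainer API shown in the test above.

# Sketch only: enabling the Liger ORPO loss from a user script.
# The model checkpoint below is a placeholder, not part of this change.
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import ORPOConfig, ORPOTrainer

model_id = "Qwen/Qwen2.5-0.5B-Instruct"  # placeholder checkpoint
model = AutoModelForCausalLM.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)
dataset = load_dataset("trl-internal-testing/zen", "standard_preference")

training_args = ORPOConfig(
    output_dir="orpo-liger",
    use_liger_loss=True,  # new flag introduced in this branch
    # base_model_attribute_name="model",  # only relevant when the model lacks `get_decoder`
)

trainer = ORPOTrainer(
    model=model,
    args=training_args,
    processing_class=tokenizer,
    train_dataset=dataset["train"],
)
trainer.train()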
@@ -45,7 +45,7 @@ from transformers import (
)
from transformers.trainer_callback import TrainerCallback
from transformers.trainer_utils import EvalLoopOutput
from transformers.utils import is_peft_available, is_torch_fx_proxy
from transformers.utils import is_liger_kernel_available, is_peft_available, is_torch_fx_proxy

from ..data_utils import maybe_apply_chat_template, maybe_extract_prompt
from .orpo_config import ORPOConfig
@@ -66,13 +66,15 @@ from .utils import (
if is_peft_available():
    from peft import PeftModel, get_peft_model, prepare_model_for_kbit_training


if is_wandb_available():
    import wandb

if is_torch_xla_available():
    import torch_xla.core.xla_model as xm

if is_liger_kernel_available():
    from liger_kernel.chunked_loss import LigerFusedLinearORPOLoss


logger = logging.get_logger(__name__)
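Editor's note: since liger-kernel is an optional dependency, a quick sanity check before enabling the flag can save a failed run; `liger_kernel` is the module name used in the import guard above.

import importlib.util

# Sketch: fail fast if the optional dependency is missing.
if importlib.util.find_spec("liger_kernel") is None:
    raise SystemExit("liger-kernel is not installed; run `pip install liger-kernel` first.")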
@@ -356,6 +358,15 @@ class ORPOTrainer(Trainer):
                "Your `Trainer` does not have an `accelerator` object. Consider upgrading `transformers`."
            )

        # Import Liger loss if enabled
        if self.args.use_liger_loss:
            if not is_liger_kernel_available():
                raise ValueError(
                    "You set `use_liger_loss=True` but the liger kernel is not available. "
                    "Please install liger-kernel first: `pip install liger-kernel`"
                )
            self.orpo_loss_fn = LigerFusedLinearORPOLoss(ignore_index=self.label_pad_token_id, beta=self.beta)

    def build_tokenized_answer(self, prompt, answer):
        """
        Llama tokenizer does satisfy `enc(a + b) = enc(a) + enc(b)`. It does ensure `enc(a + b) = enc(a) + enc(a +
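Editor's note: background on why a fused, chunked loss helps here. Materializing the full `[batch, seq_len, vocab]` logits tensor only to reduce it to a scalar loss dominates memory for large vocabularies; the same value can be obtained by projecting the hidden states through the lm_head one chunk at a time. The sketch below illustrates only the chunking idea in plain PyTorch; it is not how liger-kernel is implemented.

import torch
import torch.nn.functional as F

def chunked_ce_loss(hidden, weight, labels, chunk_size=4, ignore_index=-100):
    """Cross-entropy from hidden states without materializing the full logits tensor.

    Illustration only: real fused kernels also fuse the projection and the backward
    pass; this sketch only shows the chunking idea.
    """
    total, count = hidden.new_zeros(()), 0
    # flatten batch and sequence, then walk over row chunks
    hidden = hidden.reshape(-1, hidden.shape[-1])
    labels = labels.reshape(-1)
    for start in range(0, hidden.shape[0], chunk_size):
        h = hidden[start : start + chunk_size]
        y = labels[start : start + chunk_size]
        logits = h @ weight.T  # only a [chunk, vocab] slice exists at a time
        mask = y != ignore_index
        if mask.any():
            total = total + F.cross_entropy(logits[mask], y[mask], reduction="sum")
            count += int(mask.sum())
    return total / max(count, 1)

# Tiny self-check against the non-chunked computation (toy sizes).
torch.manual_seed(0)
hidden = torch.randn(2, 6, 16)
weight = torch.randn(32, 16)  # toy lm_head weight: [vocab, hidden]
labels = torch.randint(0, 32, (2, 6))
full = F.cross_entropy((hidden @ weight.T).reshape(-1, 32), labels.reshape(-1))
assert torch.allclose(chunked_ce_loss(hidden, weight, labels), full, atol=1e-5)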
@@ -730,59 +741,112 @@ class ORPOTrainer(Trainer):
if self.aux_loss_enabled:
model_kwargs["output_router_logits"] = True

outputs = model(
concatenated_batch["concatenated_input_ids"],
attention_mask=concatenated_batch["concatenated_attention_mask"],
use_cache=False,
**model_kwargs,
)
all_logits = outputs.logits

def cross_entropy_loss(logits, labels):
if not self.is_encoder_decoder:
# Shift so that tokens < n predict n
logits = logits[..., :-1, :].contiguous()
labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = nn.CrossEntropyLoss()
logits = logits.view(-1, logits.shape[-1])
labels = labels.view(-1)
# Enable model parallelism
labels = labels.to(logits.device)
loss = loss_fct(logits, labels)
return loss

# orpo nll target is with respect to the concatenated prompt + completion labels
if self.is_encoder_decoder:
labels = concatenated_batch["concatenated_labels"].clone()
else:
labels = concatenated_batch["concatenated_input_ids"].clone()
attention_mask = concatenated_batch["concatenated_attention_mask"]
labels = torch.where(attention_mask == 1, labels, self.label_pad_token_id)
# orpo chosen nll loss is computed over the full prompt and response
chosen_nll_loss = cross_entropy_loss(all_logits[:len_chosen], labels[:len_chosen])

all_logps = self.get_batch_logps(
all_logits,
concatenated_batch["concatenated_labels"],
average_log_prob=True,
is_encoder_decoder=self.is_encoder_decoder,
label_pad_token_id=self.label_pad_token_id,
)
if self.args.use_liger_loss:
if self.is_encoder_decoder:
# 1. Get encoder outputs
encoder_outputs = model.get_encoder()(
concatenated_batch["concatenated_input_ids"],
attention_mask=concatenated_batch["concatenated_attention_mask"],
return_dict=True,
)
# 2. Get decoder outputs
outputs = model.get_decoder()(
input_ids=model_kwargs["decoder_input_ids"],
encoder_hidden_states=encoder_outputs.last_hidden_state,
use_cache=False,
)
else:
# skip the lm head and get the last hidden state
if hasattr(model, "get_decoder"):
base_model = model.get_decoder()
else:
base_model = getattr(model, self.args.base_model_attribute_name)
outputs = base_model(
concatenated_batch["concatenated_input_ids"],
attention_mask=concatenated_batch["concatenated_attention_mask"],
use_cache=False,
**model_kwargs,
)
lm_head = model.get_output_embeddings()

chosen_logps = all_logps[:len_chosen]
rejected_logps = all_logps[len_chosen:]
# return the final loss and aux_outputs tuple
loss, aux_outputs = self.orpo_loss_fn(
lm_head.weight,
outputs.last_hidden_state[:, :-1] if not self.is_encoder_decoder else outputs.last_hidden_state,
concatenated_batch["concatenated_labels"][:, 1:]
if not self.is_encoder_decoder
else concatenated_batch["concatenated_labels"],
lm_head.bias if hasattr(lm_head, "bias") else None,
nll_target=labels[:, 1:] if not self.is_encoder_decoder else labels,
)

if not self.is_encoder_decoder:
chosen_logits = all_logits[:len_chosen, :-1, :]
rejected_logits = all_logits[len_chosen:, :-1, :]
if self.aux_loss_enabled:
loss += self.aux_loss_coef * outputs.aux_loss

return loss, aux_outputs
else:
chosen_logits = all_logits[:len_chosen]
rejected_logits = all_logits[len_chosen:]
outputs = model(
concatenated_batch["concatenated_input_ids"],
attention_mask=concatenated_batch["concatenated_attention_mask"],
use_cache=False,
output_hidden_states=False,
**model_kwargs,
)
all_logits = outputs.logits

if self.aux_loss_enabled:
return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_nll_loss, outputs.aux_loss)
def cross_entropy_loss(logits, labels):
if not self.is_encoder_decoder:
# Shift so that tokens < n predict n
logits = logits[..., :-1, :].contiguous()
labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = nn.CrossEntropyLoss()
logits = logits.view(-1, logits.shape[-1])
labels = labels.view(-1)
# Enable model parallelism
labels = labels.to(logits.device)
loss = loss_fct(logits, labels)
return loss

return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_nll_loss)
chosen_nll_loss = cross_entropy_loss(all_logits[:len_chosen], labels[:len_chosen])

all_logps = self.get_batch_logps(
all_logits,
concatenated_batch["concatenated_labels"],
average_log_prob=True,
is_encoder_decoder=self.is_encoder_decoder,
label_pad_token_id=self.label_pad_token_id,
)

chosen_logps = all_logps[:len_chosen]
rejected_logps = all_logps[len_chosen:]

if not self.is_encoder_decoder:
chosen_logits = all_logits[:len_chosen, :-1, :]
rejected_logits = all_logits[len_chosen:, :-1, :]
else:
chosen_logits = all_logits[:len_chosen]
rejected_logits = all_logits[len_chosen:]

if self.aux_loss_enabled:
return (
chosen_logps,
rejected_logps,
chosen_logits,
rejected_logits,
chosen_nll_loss,
outputs.aux_loss,
)

return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_nll_loss)

def get_batch_loss_metrics(
self,
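Editor's note: to make the fused call above easier to follow, here is a hedged plain-PyTorch sketch of the quantity a fused linear ORPO loss is expected to return from the same inputs (lm_head weight, last hidden states, shifted labels, and the `nll_target`): an NLL term over the chosen half plus a beta-weighted log-odds-ratio term, matching the `loss = policy_nll_loss - losses.mean()` composition used later in this file. The even chosen/rejected split and the shapes are assumptions of the sketch; the real LigerFusedLinearORPOLoss is a chunked kernel with its own internals and return layout.

import torch
import torch.nn.functional as F

def reference_orpo_loss(weight, hidden, labels, nll_target, beta=0.1, ignore_index=-100):
    """Un-fused reference: hidden -> logits -> NLL + odds-ratio term (illustration only)."""
    logits = hidden @ weight.T                # [2B, T, vocab]; the fused kernel avoids this
    len_chosen = hidden.shape[0] // 2         # assumes chosen examples come first

    # NLL over the chosen half of the concatenated batch (prompt + completion)
    nll = F.cross_entropy(
        logits[:len_chosen].reshape(-1, logits.shape[-1]),
        nll_target[:len_chosen].reshape(-1),
        ignore_index=ignore_index,
    )

    # average per-token log-probs of the labelled tokens, per sequence
    mask = labels != ignore_index
    per_token = torch.gather(
        logits.log_softmax(-1), 2, labels.masked_fill(~mask, 0).unsqueeze(2)
    ).squeeze(2)
    avg_logps = (per_token * mask).sum(-1) / mask.sum(-1).clamp(min=1)
    chosen_logps, rejected_logps = avg_logps[:len_chosen], avg_logps[len_chosen:]

    # log odds ratio between chosen and rejected completions, odds(p) = p / (1 - p)
    log_odds = (chosen_logps - rejected_logps) - (
        torch.log1p(-torch.exp(chosen_logps)) - torch.log1p(-torch.exp(rejected_logps))
    )
    return nll - beta * F.logsigmoid(log_odds).mean()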
@@ -794,48 +858,61 @@ class ORPOTrainer(Trainer):
metrics = {}

forward_output = self.concatenated_forward(model, batch)
(
policy_chosen_logps,
policy_rejected_logps,
policy_chosen_logits,
policy_rejected_logits,
policy_nll_loss,
) = forward_output[:5]
if self.aux_loss_enabled:
aux_loss = forward_output[5]
if self.args.use_liger_loss:
# full ORPO loss and aux outputs
(
loss,
(
policy_chosen_logps,
policy_rejected_logps,
policy_chosen_logits,
policy_rejected_logits,
policy_nll_loss,
chosen_rewards,
rejected_rewards,
log_odds_ratio,
log_odds_chosen,
),
) = forward_output
else:
(
policy_chosen_logps,
policy_rejected_logps,
policy_chosen_logits,
policy_rejected_logits,
policy_nll_loss,
) = forward_output[:5]
if self.aux_loss_enabled:
aux_loss = forward_output[5]

losses, chosen_rewards, rejected_rewards, log_odds_ratio, log_odds_chosen = self.odds_ratio_loss(
policy_chosen_logps, policy_rejected_logps
)
# full ORPO loss
loss = policy_nll_loss - losses.mean()
losses, chosen_rewards, rejected_rewards, log_odds_ratio, log_odds_chosen = self.odds_ratio_loss(
policy_chosen_logps, policy_rejected_logps
)
# full ORPO loss
loss = policy_nll_loss - losses.mean()

if self.aux_loss_enabled:
loss += self.aux_loss_coef * aux_loss

reward_accuracies = (chosen_rewards > rejected_rewards).float()

prefix = "eval_" if train_eval == "eval" else ""
metrics[f"{prefix}rewards/chosen"] = self.accelerator.gather_for_metrics(chosen_rewards).mean()
metrics[f"{prefix}rewards/rejected"] = self.accelerator.gather_for_metrics(rejected_rewards).mean()
metrics[f"{prefix}rewards/accuracies"] = self.accelerator.gather_for_metrics(reward_accuracies).mean()
metrics[f"{prefix}rewards/margins"] = self.accelerator.gather_for_metrics(
chosen_rewards - rejected_rewards
).mean()
metrics[f"{prefix}logps/rejected"] = self.accelerator.gather_for_metrics(policy_rejected_logps).detach().mean()
metrics[f"{prefix}logps/chosen"] = self.accelerator.gather_for_metrics(policy_chosen_logps).detach().mean()
metrics[f"{prefix}logits/rejected"] = self.accelerator.gather_for_metrics(
policy_rejected_logits.detach().mean()
).mean()
metrics[f"{prefix}logits/chosen"] = self.accelerator.gather_for_metrics(
policy_chosen_logits.detach().mean()
).mean()
metrics[f"{prefix}nll_loss"] = self.accelerator.gather_for_metrics(policy_nll_loss).detach().mean()
metrics[f"{prefix}log_odds_ratio"] = self.accelerator.gather_for_metrics(log_odds_ratio).detach().mean()
metrics[f"{prefix}log_odds_chosen"] = self.accelerator.gather_for_metrics(log_odds_chosen).detach().mean()
metrics[f"{prefix}rewards/chosen"] = chosen_rewards.mean()
metrics[f"{prefix}rewards/rejected"] = rejected_rewards.mean()
metrics[f"{prefix}rewards/accuracies"] = reward_accuracies.mean()
metrics[f"{prefix}rewards/margins"] = (chosen_rewards - rejected_rewards).mean()
metrics[f"{prefix}logps/rejected"] = policy_rejected_logps.detach().mean()
metrics[f"{prefix}logps/chosen"] = policy_chosen_logps.detach().mean()
metrics[f"{prefix}logits/rejected"] = policy_rejected_logits.detach().mean()
metrics[f"{prefix}logits/chosen"] = policy_chosen_logits.detach().mean()
metrics[f"{prefix}nll_loss"] = policy_nll_loss.detach().mean()
metrics[f"{prefix}log_odds_ratio"] = log_odds_ratio
metrics[f"{prefix}log_odds_chosen"] = log_odds_chosen

if is_torch_xla_available():
xm.mark_step()  # needed because .item() calls
for k, v in metrics.items():
metrics[k] = v.item()
if self.aux_loss_enabled:
loss += self.aux_loss_coef * aux_loss

return loss, metrics
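Editor's note: a small worked example of the logged quantities, assuming the usual TRL/ORPO conventions (rewards as beta-scaled average log-probs, the odds-ratio term as `beta * logsigmoid(log_odds)`); the numbers are arbitrary and only meant to give a feel for the scales involved.

import torch
import torch.nn.functional as F

beta = 0.1
chosen_logps = torch.tensor([-0.4])    # average log p(chosen completion | prompt)
rejected_logps = torch.tensor([-0.9])  # average log p(rejected completion | prompt)

# Rewards are logged as beta-scaled average log-probs (assumed TRL/ORPO convention).
chosen_rewards = beta * chosen_logps      # tensor([-0.04])
rejected_rewards = beta * rejected_logps  # tensor([-0.09])

# log odds ratio between the two completions, odds(p) = p / (1 - p)
log_odds = (chosen_logps - rejected_logps) - (
    torch.log1p(-torch.exp(chosen_logps)) - torch.log1p(-torch.exp(rejected_logps))
)                                         # roughly 1.09 here
losses = beta * F.logsigmoid(log_odds)    # preference term, roughly -0.029

# as in the hunk above: full loss = NLL on the chosen half minus the mean preference term
policy_nll_loss = torch.tensor(1.5)       # placeholder NLL value
loss = policy_nll_loss - losses.mean()    # roughly 1.53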