Mirror of https://github.com/huggingface/trl.git, synced 2025-10-20 18:43:52 +08:00

Integrate OrpoTrainer with PyTorchXLA for faster step time on TPUs (#2001)
* make Orpotrainer run faster on tpu
* less data transfer
* train-trl.py
* fix
* set device_map=auto
* add is_torch_xla_available guards
* delete file
* address comments
* make presubmit
* Update transformer version in setup.py

---------

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
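
As a quick orientation before the diff (this sketch is not part of the commit), the TPU-specific behaviour described above is only reached behind a guard: transformers' is_torch_xla_available() decides whether the torch_xla runtime helpers are imported and used at all. A minimal sketch of that guard pattern, assuming a torch_xla installation; the helper name flush_pending_xla_work is hypothetical:

    from transformers.utils import is_torch_xla_available

    if is_torch_xla_available():
        import torch_xla.core.xla_model as xm  # XLA runtime helpers (TPU builds only)

    def flush_pending_xla_work():
        # Hypothetical helper for illustration: force queued XLA ops to execute
        # before any host-side reads such as tensor.item().
        if is_torch_xla_available():
            xm.mark_step()
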
@@ -1,7 +1,7 @@
 datasets>=1.17.0
 torch>=1.4.0
 tqdm
-transformers
+transformers>=4.39.0
 accelerate
 peft>=0.3.0
 tyro>=0.5.7

setup.py
@@ -63,7 +63,7 @@ __version__ = "0.11.0.dev0"  # expected format is one of x.y.z.dev0, or x.y.z.rc
 
 REQUIRED_PKGS = [
     "torch>=1.4.0",
-    "transformers>=4.31.0",
+    "transformers>=4.39.0",
     "numpy>=1.18.2;platform_system!='Windows'",
     "numpy<2;platform_system=='Windows'",
     "accelerate",

@@ -35,7 +35,7 @@ from torch.utils.data import DataLoader
 from transformers import AutoModelForCausalLM, DataCollator, PreTrainedModel, PreTrainedTokenizerBase, Trainer
 from transformers.trainer_callback import TrainerCallback
 from transformers.trainer_utils import EvalLoopOutput
-from transformers.utils import is_torch_fx_proxy
+from transformers.utils import is_torch_fx_proxy, is_torch_xla_available
 
 from ..import_utils import is_peft_available, is_wandb_available
 from ..models import PreTrainedModelWrapper

@@ -61,6 +61,9 @@ if is_wandb_available():
 if is_deepspeed_available():
     import deepspeed
 
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
 
 class ORPOTrainer(Trainer):
     r"""

@@ -534,6 +537,16 @@ class ORPOTrainer(Trainer):
                     labels=torch.tensor(batch["chosen_labels"])
                 )
 
+        if is_torch_xla_available():
+            # Pad the sequences to global max_length to avoid TorchXLA recompilation
+            for k in batch:
+                if "labels" in k or self.is_encoder_decoder:
+                    pad_value = self.label_pad_token_id
+                elif k.endswith("_input_ids"):
+                    pad_value = self.padding_value
+                elif k.endswith("_attention_mask"):
+                    pad_value = 0
+                batch[k] = batch[k] + [pad_value] * (self.max_length - len(batch[k]))
         return batch
 
     @staticmethod

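
The hunk above pads every field of the tokenized batch out to the trainer's max_length whenever torch_xla is present. The motivation: XLA compiles a fresh graph for each new tensor shape it encounters, so batches whose sequence lengths vary from step to step cause repeated recompilation; padding to one global length keeps the shapes static. A rough standalone illustration of the same idea (pad_to_max_length is a hypothetical name, not part of the diff):

    def pad_to_max_length(token_ids, max_length, pad_value):
        # Right-pad a token list to a fixed length so every step presents
        # the same tensor shape and the compiled XLA graph can be reused.
        return token_ids + [pad_value] * (max_length - len(token_ids))

    padded = pad_to_max_length([101, 2023, 102], max_length=8, pad_value=0)
    # padded -> [101, 2023, 102, 0, 0, 0, 0, 0]
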
@@ -628,7 +641,7 @@ class ORPOTrainer(Trainer):
         chosen_rewards = self.beta * (policy_chosen_logps.to(self.accelerator.device)).detach()
         rejected_rewards = self.beta * (policy_rejected_logps.to(self.accelerator.device)).detach()
 
-        return losses, chosen_rewards, rejected_rewards, torch.mean(ratio).item(), torch.mean(log_odds).item()
+        return losses, chosen_rewards, rejected_rewards, torch.mean(ratio), torch.mean(log_odds)
 
     @staticmethod
     def get_batch_logps(

@@ -659,7 +672,7 @@ class ORPOTrainer(Trainer):
         loss_mask = labels != label_pad_token_id
 
         # dummy token; we'll ignore the losses on these tokens later
-        labels[labels == label_pad_token_id] = 0
+        labels = torch.where(labels == label_pad_token_id, 0, labels)
 
         per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2)
 
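
This hunk swaps an in-place masked assignment for torch.where, which builds a new tensor rather than mutating labels; the functional form traces more cleanly under XLA and leaves any caller-held reference to the original tensor untouched. A small illustration with made-up values:

    import torch

    label_pad_token_id = -100
    labels = torch.tensor([5, -100, 7, -100])

    # Functional form used in the diff: produce a cleaned copy, keep `labels` intact.
    cleaned = torch.where(labels == label_pad_token_id, 0, labels)
    # cleaned -> tensor([5, 0, 7, 0])
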
@@ -774,18 +787,21 @@ class ORPOTrainer(Trainer):
         reward_accuracies = (chosen_rewards > rejected_rewards).float()
 
         prefix = "eval_" if train_eval == "eval" else ""
-        metrics[f"{prefix}rewards/chosen"] = chosen_rewards.mean().cpu()
-        metrics[f"{prefix}rewards/rejected"] = rejected_rewards.mean().cpu()
-        metrics[f"{prefix}rewards/accuracies"] = reward_accuracies.mean().cpu()
-        metrics[f"{prefix}rewards/margins"] = (chosen_rewards - rejected_rewards).mean().cpu()
-        metrics[f"{prefix}logps/rejected"] = policy_rejected_logps.detach().mean().cpu()
-        metrics[f"{prefix}logps/chosen"] = policy_chosen_logps.detach().mean().cpu()
-        metrics[f"{prefix}logits/rejected"] = policy_rejected_logits.detach().mean().cpu()
-        metrics[f"{prefix}logits/chosen"] = policy_chosen_logits.detach().mean().cpu()
-        metrics[f"{prefix}nll_loss"] = policy_nll_loss.detach().mean().cpu()
+        metrics[f"{prefix}rewards/chosen"] = chosen_rewards.mean()
+        metrics[f"{prefix}rewards/rejected"] = rejected_rewards.mean()
+        metrics[f"{prefix}rewards/accuracies"] = reward_accuracies.mean()
+        metrics[f"{prefix}rewards/margins"] = (chosen_rewards - rejected_rewards).mean()
+        metrics[f"{prefix}logps/rejected"] = policy_rejected_logps.detach().mean()
+        metrics[f"{prefix}logps/chosen"] = policy_chosen_logps.detach().mean()
+        metrics[f"{prefix}logits/rejected"] = policy_rejected_logits.detach().mean()
+        metrics[f"{prefix}logits/chosen"] = policy_chosen_logits.detach().mean()
+        metrics[f"{prefix}nll_loss"] = policy_nll_loss.detach().mean()
         metrics[f"{prefix}log_odds_ratio"] = log_odds_ratio
         metrics[f"{prefix}log_odds_chosen"] = log_odds_chosen
-
+        if is_torch_xla_available():
+            xm.mark_step()  # needed because .item() calls
+        for k, v in metrics.items():
+            metrics[k] = v.item()
         if self.aux_loss_enabled:
             loss += getattr(model.config, "router_aux_loss_coef", 0.0) * aux_loss
 
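
Together with the earlier change that stops calling .item() inside the loss computation, this final hunk keeps every metric as a device tensor until the end, issues a single xm.mark_step() so the pending XLA graph executes once, and only then reads each scalar back with .item(). A hedged, self-contained sketch of that pattern (the metric names and values are placeholders):

    import torch
    from transformers.utils import is_torch_xla_available

    if is_torch_xla_available():
        import torch_xla.core.xla_model as xm

    # Placeholder tensors standing in for the real reward/log-prob means.
    metrics = {
        "rewards/chosen": torch.tensor(0.73),
        "rewards/rejected": torch.tensor(0.21),
    }

    if is_torch_xla_available():
        xm.mark_step()  # run the queued XLA work once, not once per .item()
    metrics = {k: v.item() for k, v in metrics.items()}
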