Integrate ORPOTrainer with PyTorch/XLA for faster step time on TPUs (#2001)

* make ORPOTrainer run faster on TPU

* less data transfer

* train-trl.py

* fix

* set device_map=auto

* add is_torch_xla_available guards

* delete file

* address comments

* make presubmit

* Update transformers version in setup.py

---------

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
Author: wenxindongwork
Date: 2024-09-11 06:11:28 -07:00
Committed by: GitHub
Parent: 37934d70a9
Commit: e2966c8d99
3 changed files with 31 additions and 15 deletions
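
For context, here is a minimal usage sketch (not part of this commit) of the kind of script the change targets: ORPO fine-tuning with trl on a TPU VM under PyTorch/XLA. The model name, dataset name, and hyperparameters are placeholders, not values from the PR.

from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import ORPOConfig, ORPOTrainer

# Placeholder model and dataset; any preference dataset with the columns
# expected by ORPOTrainer would do.
model = AutoModelForCausalLM.from_pretrained("my-org/my-base-model")
tokenizer = AutoTokenizer.from_pretrained("my-org/my-base-model")
train_dataset = load_dataset("my-org/my-preference-dataset", split="train")

args = ORPOConfig(
    output_dir="orpo-tpu",
    max_length=512,         # a fixed max_length keeps XLA input shapes static
    max_prompt_length=128,
    per_device_train_batch_size=4,
)
trainer = ORPOTrainer(model=model, args=args, train_dataset=train_dataset, tokenizer=tokenizer)
trainer.train()

On a TPU VM this would typically be launched through PyTorch/XLA's or accelerate's multi-core launcher; on GPU/CPU it runs unchanged, since every XLA-specific code path below is guarded by is_torch_xla_available().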

requirements.txt

@@ -1,7 +1,7 @@
 datasets>=1.17.0
 torch>=1.4.0
 tqdm
-transformers
+transformers>=4.39.0
 accelerate
 peft>=0.3.0
 tyro>=0.5.7

setup.py

@@ -63,7 +63,7 @@ __version__ = "0.11.0.dev0" # expected format is one of x.y.z.dev0, or x.y.z.rc
 REQUIRED_PKGS = [
     "torch>=1.4.0",
-    "transformers>=4.31.0",
+    "transformers>=4.39.0",
     "numpy>=1.18.2;platform_system!='Windows'",
     "numpy<2;platform_system=='Windows'",
     "accelerate",

trl/trainer/orpo_trainer.py

@@ -35,7 +35,7 @@ from torch.utils.data import DataLoader
 from transformers import AutoModelForCausalLM, DataCollator, PreTrainedModel, PreTrainedTokenizerBase, Trainer
 from transformers.trainer_callback import TrainerCallback
 from transformers.trainer_utils import EvalLoopOutput
-from transformers.utils import is_torch_fx_proxy
+from transformers.utils import is_torch_fx_proxy, is_torch_xla_available
 from ..import_utils import is_peft_available, is_wandb_available
 from ..models import PreTrainedModelWrapper
@@ -61,6 +61,9 @@ if is_wandb_available():
 if is_deepspeed_available():
     import deepspeed
 
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
 class ORPOTrainer(Trainer):
     r"""
@@ -534,6 +537,16 @@ class ORPOTrainer(Trainer):
                     labels=torch.tensor(batch["chosen_labels"])
                 )
 
+        if is_torch_xla_available():
+            # Pad the sequences to global max_length to avoid TorchXLA recompilation
+            for k in batch:
+                if "labels" in k or self.is_encoder_decoder:
+                    pad_value = self.label_pad_token_id
+                elif k.endswith("_input_ids"):
+                    pad_value = self.padding_value
+                elif k.endswith("_attention_mask"):
+                    pad_value = 0
+                batch[k] = batch[k] + [pad_value] * (self.max_length - len(batch[k]))
         return batch
 
     @staticmethod
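
Why pad here at all: XLA compiles one program per distinct tensor shape, so per-batch dynamic padding would trigger a recompilation whenever a new sequence length appears. Padding every tokenized feature to the single configured max_length keeps shapes static. A standalone sketch of the same idea, using illustrative values rather than trl configuration:

from typing import List

def pad_to_fixed_length(ids: List[int], max_length: int = 512, pad_id: int = 0) -> List[int]:
    # Truncate if needed, then right-pad so every example has exactly max_length tokens.
    ids = ids[:max_length]
    return ids + [pad_id] * (max_length - len(ids))

# Every call now yields a length-512 list, so the downstream XLA graph is
# traced once instead of once per observed sequence length.
print(len(pad_to_fixed_length([101, 2023, 2003, 102])))  # 512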
@@ -628,7 +641,7 @@ class ORPOTrainer(Trainer):
         chosen_rewards = self.beta * (policy_chosen_logps.to(self.accelerator.device)).detach()
         rejected_rewards = self.beta * (policy_rejected_logps.to(self.accelerator.device)).detach()
 
-        return losses, chosen_rewards, rejected_rewards, torch.mean(ratio).item(), torch.mean(log_odds).item()
+        return losses, chosen_rewards, rejected_rewards, torch.mean(ratio), torch.mean(log_odds)
 
     @staticmethod
     def get_batch_logps(
@@ -659,7 +672,7 @@ class ORPOTrainer(Trainer):
         loss_mask = labels != label_pad_token_id
 
         # dummy token; we'll ignore the losses on these tokens later
-        labels[labels == label_pad_token_id] = 0
+        labels = torch.where(labels == label_pad_token_id, 0, labels)
 
         per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2)
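
The switch from in-place masked assignment to torch.where is behavior-preserving; the functional form avoids boolean-mask indexing, which tends to lower to XLA less efficiently than a fixed-shape elementwise select. A small self-contained check of the equivalence (the example tensors are illustrative, not from the trainer):

import torch

label_pad_token_id = -100
labels = torch.tensor([[5, 7, -100, -100], [3, -100, -100, -100]])

# In-place masked assignment (old code path):
a = labels.clone()
a[a == label_pad_token_id] = 0

# Functional, fixed-shape equivalent (new code path):
b = torch.where(labels == label_pad_token_id, 0, labels)

assert torch.equal(a, b)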
@@ -774,18 +787,21 @@ class ORPOTrainer(Trainer):
         reward_accuracies = (chosen_rewards > rejected_rewards).float()
 
         prefix = "eval_" if train_eval == "eval" else ""
-        metrics[f"{prefix}rewards/chosen"] = chosen_rewards.mean().cpu()
-        metrics[f"{prefix}rewards/rejected"] = rejected_rewards.mean().cpu()
-        metrics[f"{prefix}rewards/accuracies"] = reward_accuracies.mean().cpu()
-        metrics[f"{prefix}rewards/margins"] = (chosen_rewards - rejected_rewards).mean().cpu()
-        metrics[f"{prefix}logps/rejected"] = policy_rejected_logps.detach().mean().cpu()
-        metrics[f"{prefix}logps/chosen"] = policy_chosen_logps.detach().mean().cpu()
-        metrics[f"{prefix}logits/rejected"] = policy_rejected_logits.detach().mean().cpu()
-        metrics[f"{prefix}logits/chosen"] = policy_chosen_logits.detach().mean().cpu()
-        metrics[f"{prefix}nll_loss"] = policy_nll_loss.detach().mean().cpu()
+        metrics[f"{prefix}rewards/chosen"] = chosen_rewards.mean()
+        metrics[f"{prefix}rewards/rejected"] = rejected_rewards.mean()
+        metrics[f"{prefix}rewards/accuracies"] = reward_accuracies.mean()
+        metrics[f"{prefix}rewards/margins"] = (chosen_rewards - rejected_rewards).mean()
+        metrics[f"{prefix}logps/rejected"] = policy_rejected_logps.detach().mean()
+        metrics[f"{prefix}logps/chosen"] = policy_chosen_logps.detach().mean()
+        metrics[f"{prefix}logits/rejected"] = policy_rejected_logits.detach().mean()
+        metrics[f"{prefix}logits/chosen"] = policy_chosen_logits.detach().mean()
+        metrics[f"{prefix}nll_loss"] = policy_nll_loss.detach().mean()
         metrics[f"{prefix}log_odds_ratio"] = log_odds_ratio
         metrics[f"{prefix}log_odds_chosen"] = log_odds_chosen
+        if is_torch_xla_available():
+            xm.mark_step()  # needed because .item() calls
+        for k, v in metrics.items():
+            metrics[k] = v.item()
 
         if self.aux_loss_enabled:
             loss += getattr(model.config, "router_aux_loss_coef", 0.0) * aux_loss
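
Taken together, the last two hunks move all host syncs to one place: odds_ratio_loss now returns lazy tensors instead of calling .item() per value, the metrics stay on device (no per-metric .cpu()), and a single xm.mark_step() runs before the batched .item() reads. A minimal sketch of the effect, assuming torch_xla is installed and an XLA device is available:

import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
metrics = {
    "rewards/chosen": torch.randn(8, device=device).mean(),
    "rewards/rejected": torch.randn(8, device=device).mean(),
}

# Cut and execute the pending lazy graph once, rather than letting each
# .item() call force its own execution and device-to-host round trip.
xm.mark_step()
logged = {k: v.item() for k, v in metrics.items()}  # reads already-materialized scalars
print(logged)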