[NashMD] fix the edge case where the model is a peft model (#3473)

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
Author: Kashif Rasul
Date: 2025-05-20 17:02:04 +02:00 (committed by GitHub)
Parent: e0dd525021
Commit: a528b9c465
4 changed files with 125 additions and 22 deletions
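Both NashMDTrainer and XPOTrainer previously reused the policy itself as the implicit reference model when ref_model was None, which does not handle the case where the policy passed in is already a PEFT model. A minimal standalone sketch of the reference-resolution logic the diffs below introduce (the helper name resolve_ref_for_generation is hypothetical, used only for illustration):

import torch
from accelerate import Accelerator
from transformers.utils import is_peft_available

if is_peft_available():
    from peft import PeftModel


def resolve_ref_for_generation(model: torch.nn.Module, ref_model, accelerator: Accelerator) -> torch.nn.Module:
    # DDP/FSDP-unwrap the policy first; if it is a PEFT model it stays a PeftModel.
    policy = accelerator.unwrap_model(model)
    if ref_model is not None:
        # An explicit reference model only needs DDP/FSDP unwrapping.
        return accelerator.unwrap_model(ref_model)
    if is_peft_available() and isinstance(policy, PeftModel):
        # The edge case fixed here: a policy that is already a PeftModel, with no
        # explicit ref_model, falls back to the module returned by get_base_model().
        return policy.get_base_model()
    # No PEFT involved: the unwrapped policy itself doubles as the reference.
    return policy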

tests/test_nash_md_trainer.py

@@ -160,6 +160,38 @@ class TestNashMDTrainer(unittest.TestCase):
         # Check if training loss is available
         self.assertIn("train_loss", trainer.state.log_history[-1])
 
+    @require_peft
+    def test_training_pre_pefted_model_implicit_ref_with_reward_model(self):
+        lora_config = LoraConfig(r=8, lora_alpha=16, lora_dropout=0.1, bias="none", task_type="CAUSAL_LM")
+        # self.model from setUp is a base AutoModelForCausalLM
+        peft_model_instance = get_peft_model(self.model, lora_config)
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            training_args = NashMDConfig(
+                output_dir=tmp_dir,
+                per_device_train_batch_size=1,  # Keep small for quick test
+                max_steps=2,  # Few steps
+                learning_rate=5.0e-7,
+                eval_strategy="no",
+                report_to="none",
+                remove_unused_columns=False,  # Important for the dummy dataset
+            )
+            dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only")["train"]
+
+            trainer = NashMDTrainer(
+                model=peft_model_instance,  # Pass the already PEFT model
+                ref_model=None,  # Implicit reference from peft_model_instance's base
+                reward_model=self.reward_model,  # To trigger GeometricMixtureWrapper path
+                args=training_args,
+                processing_class=self.tokenizer,
+                train_dataset=dummy_dataset,
+                # peft_config is not passed, as model is already PEFT
+            )
+
+            trainer.train()
+
+            self.assertIn("train_loss", trainer.state.log_history[-1])
+
     @parameterized.expand([("standard_prompt_only",), ("conversational_prompt_only",)])
     @require_llm_blender
     def test_nash_md_trainer_judge_training(self, config_name):

tests/test_xpo_trainer.py

@@ -160,6 +160,36 @@ class TestXPOTrainer(unittest.TestCase):
         # Check if training loss is available
         self.assertIn("train_loss", trainer.state.log_history[-1])
 
+    @require_peft
+    def test_training_pre_pefted_model_implicit_ref(self):
+        lora_config = LoraConfig(r=8, lora_alpha=16, lora_dropout=0.1, bias="none", task_type="CAUSAL_LM")
+        peft_model_instance = get_peft_model(self.model, lora_config)
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            training_args = XPOConfig(
+                output_dir=tmp_dir,
+                per_device_train_batch_size=1,
+                max_steps=2,
+                learning_rate=5.0e-7,
+                eval_strategy="no",
+                report_to="none",
+                remove_unused_columns=False,
+            )
+            dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only")["train"]
+
+            trainer = XPOTrainer(
+                model=peft_model_instance,
+                ref_model=None,
+                reward_model=self.reward_model,  # Using reward_model to ensure _generate_completions is used as expected
+                args=training_args,
+                processing_class=self.tokenizer,
+                train_dataset=dummy_dataset,
+            )
+
+            trainer.train()
+
+            self.assertIn("train_loss", trainer.state.log_history[-1])
+
     @require_llm_blender
     @parameterized.expand([("standard_prompt_only",), ("conversational_prompt_only",)])
     def test_xpo_trainer_judge_training(self, config_name):

trl/trainer/nash_md_trainer.py

@@ -32,7 +32,7 @@ from transformers import (
 )
 from transformers.trainer_utils import EvalPrediction
 from transformers.training_args import OptimizerNames
-from transformers.utils import is_apex_available
+from transformers.utils import is_apex_available, is_peft_available
 
 from ..data_utils import is_conversational, maybe_apply_chat_template
 from ..models.modeling_base import GeometricMixtureWrapper
@@ -59,6 +59,10 @@ if is_wandb_available():
     import wandb
 
+if is_peft_available():
+    from peft import PeftModel
+
 
 class NashMDTrainer(OnlineDPOTrainer):
     r"""
     Initialize NashMDTrainer as a subclass of [`OnlineDPOConfig`].
@@ -170,28 +174,50 @@ class NashMDTrainer(OnlineDPOTrainer):
         return self._mixture_coef
 
     def _generate_completions(self, model, prompts):
-        with unwrap_model_for_generation(model, self.accelerator) as unwrapped_model:
-            model_output = unwrapped_model.generate(
+        # Generate completions from the policy model.
+        with unwrap_model_for_generation(model, self.accelerator) as unwrapped_policy_for_gen_ctx:
+            model_output = unwrapped_policy_for_gen_ctx.generate(
                 input_ids=prompts["input_ids"],
                 attention_mask=prompts["attention_mask"],
                 generation_config=self.generation_config,
             )
 
-            ref_model = model if self.ref_model is None else self.ref_model
-            with torch.no_grad(), unwrap_model_for_generation(ref_model, self.accelerator) as unwrapped_ref_model:
-                mixture_model = GeometricMixtureWrapper(
-                    model=unwrapped_model,
-                    ref_model=unwrapped_ref_model,
-                    generation_config=self.generation_config,
-                    mixture_coef=self.mixture_coef,
-                    device=self.accelerator.device,
-                )
+        # Get the DDP/FSDP unwrapped version of the main model.
+        # This will be the policy model for GeometricMixtureWrapper (PEFT adapters active if PEFT is used).
+        policy_model_for_gmw = self.accelerator.unwrap_model(model)
 
-                mixture_output = mixture_model.generate(
-                    input_ids=prompts["input_ids"],
-                    attention_mask=prompts["attention_mask"],
-                    generation_config=self.generation_config,
-                )
+        # Determine the correct reference model for GeometricMixtureWrapper.
+        # This also needs to be DDP/FSDP unwrapped.
+        ref_model_for_gmw: torch.nn.Module
+        if self.ref_model is None:
+            # No explicit ref_model is provided.
+            # Use the base of the main `model` if it's a PEFT model.
+            # policy_model_for_gmw is already DDP-unwrapped.
+            if is_peft_available() and isinstance(policy_model_for_gmw, PeftModel):
+                ref_model_for_gmw = policy_model_for_gmw.get_base_model()
+            else:
+                # Not a PEFT model (or PEFT not available), or already a base model.
+                # Use the DDP-unwrapped policy model itself as the reference.
+                ref_model_for_gmw = policy_model_for_gmw
+        else:
+            # An explicit ref_model is provided. Unwrap it for DDP/FSDP.
+            ref_model_for_gmw = self.accelerator.unwrap_model(self.ref_model)
+
+        # Both models given to GeometricMixtureWrapper (policy_model_for_gmw and ref_model_for_gmw) are DDP-unwrapped.
+        with torch.no_grad():  # Ensure no_grad context for mixture model generation
+            mixture_model = GeometricMixtureWrapper(
+                model=policy_model_for_gmw,
+                ref_model=ref_model_for_gmw,
+                generation_config=self.generation_config,
+                mixture_coef=self.mixture_coef,
+                device=self.accelerator.device,
+            )
+
+            mixture_output = mixture_model.generate(
+                input_ids=prompts["input_ids"],
+                attention_mask=prompts["attention_mask"],
+                generation_config=self.generation_config,
+            )
 
         return model_output, mixture_output
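GeometricMixtureWrapper is why the reference-model choice matters here: it needs a separate policy and reference to sample from a geometric mixture of the two next-token distributions. A rough sketch of the per-token mixture it is understood to apply, weighted by mixture_coef (this is an assumption about trl.models.modeling_base.GeometricMixtureWrapper, not shown in this diff):

import torch
import torch.nn.functional as F


def geometric_mixture_logits(policy_logits: torch.Tensor, ref_logits: torch.Tensor, mixture_coef: float) -> torch.Tensor:
    # Geometric mixture of the two next-token distributions:
    # log p_mix = (1 - mixture_coef) * log p_policy + mixture_coef * log p_ref, renormalized.
    # Sketched here as a convex combination of raw logits followed by log_softmax.
    return F.log_softmax(mixture_coef * ref_logits + (1 - mixture_coef) * policy_logits, dim=-1)

After the fix, when the policy is a PeftModel and no ref_model is given, the wrapper's ref_model argument is the module returned by get_base_model(), while the model argument keeps the PEFT adapters active, as the comments in the hunk above state.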

trl/trainer/xpo_trainer.py

@@ -33,6 +33,7 @@ from transformers import (
 )
 from transformers.trainer_utils import EvalPrediction
 from transformers.training_args import OptimizerNames
+from transformers.utils import is_peft_available
 
 from ..data_utils import is_conversational, maybe_apply_chat_template
 from ..models.utils import unwrap_model_for_generation
@@ -58,6 +59,10 @@ if is_wandb_available():
     import wandb
 
+if is_peft_available():
+    from peft import PeftModel
+
 
 class XPOTrainer(OnlineDPOTrainer):
     r"""
     Initialize XPOTrainer as a subclass of [`OnlineDPOConfig`].
@@ -174,16 +179,26 @@ class XPOTrainer(OnlineDPOTrainer):
         return self._alpha
 
     def _generate_completions(self, prompts, model):
-        with unwrap_model_for_generation(model, self.accelerator) as unwrapped_model:
-            model_output = unwrapped_model.generate(
+        with unwrap_model_for_generation(model, self.accelerator) as unwrapped_policy_model_for_gen:
+            model_output = unwrapped_policy_model_for_gen.generate(
                 input_ids=prompts["input_ids"],
                 attention_mask=prompts["attention_mask"],
                 generation_config=self.generation_config,
             )
 
-        ref_model = model if self.ref_model is None else self.ref_model
-        with torch.no_grad(), unwrap_model_for_generation(ref_model, self.accelerator) as unwrapped_ref_model:
-            ref_output = unwrapped_ref_model.generate(
+        actual_model_for_ref_generation: torch.nn.Module
+        if self.ref_model is None:
+            unwrapped_main_model_for_ref_logic = self.accelerator.unwrap_model(model)
+            if is_peft_available() and isinstance(unwrapped_main_model_for_ref_logic, PeftModel):
+                actual_model_for_ref_generation = unwrapped_main_model_for_ref_logic.get_base_model()
+            else:
+                actual_model_for_ref_generation = unwrapped_main_model_for_ref_logic
+        else:
+            actual_model_for_ref_generation = self.accelerator.unwrap_model(self.ref_model)
+
+        with unwrap_model_for_generation(actual_model_for_ref_generation, self.accelerator) as final_ref_model_for_gen:
+            ref_output = final_ref_model_for_gen.generate(
                 input_ids=prompts["input_ids"],
                 attention_mask=prompts["attention_mask"],
                 generation_config=self.generation_config,
             )
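For reference, a minimal end-to-end usage sketch mirroring the new tests: the policy is wrapped with PEFT before being handed to the trainer and ref_model is left as None, so the implicit reference is resolved from the PEFT base model. The checkpoint names, reward model, and output directory below are placeholders chosen for illustration; they are not part of the commit.

from datasets import load_dataset
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer
from trl import XPOConfig, XPOTrainer

model_id = "Qwen/Qwen2.5-0.5B-Instruct"  # placeholder policy checkpoint
reward_id = "trl-lib/Qwen2-0.5B-Reward"  # placeholder reward model checkpoint

policy = get_peft_model(
    AutoModelForCausalLM.from_pretrained(model_id),
    LoraConfig(r=8, lora_alpha=16, lora_dropout=0.1, bias="none", task_type="CAUSAL_LM"),
)
tokenizer = AutoTokenizer.from_pretrained(model_id)
reward_model = AutoModelForSequenceClassification.from_pretrained(reward_id, num_labels=1)

trainer = XPOTrainer(
    model=policy,              # already a PeftModel; no peft_config is passed
    ref_model=None,            # implicit reference resolved from the PEFT base model
    reward_model=reward_model,
    args=XPOConfig(output_dir="xpo-peft-demo", per_device_train_batch_size=1, max_steps=2),
    processing_class=tokenizer,
    train_dataset=load_dataset("trl-internal-testing/zen", "standard_prompt_only")["train"],
)
trainer.train()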