diff --git a/trl/trainer/online_dpo_config.py b/trl/trainer/online_dpo_config.py
index 67dfa3b25..8b4a9472c 100644
--- a/trl/trainer/online_dpo_config.py
+++ b/trl/trainer/online_dpo_config.py
@@ -412,3 +412,10 @@ class OnlineDPOConfig(TrainingArguments):
         if hasattr(self.beta, "__len__") and len(self.beta) == 1:
             self.beta = self.beta[0]
+
+        if self.max_new_tokens >= self.max_length:
+            warnings.warn(
+                f"The configuration has `max_new_tokens` ({self.max_new_tokens}) >= `max_length` ({self.max_length}). "
+                "This will cause prompts to be truncated or completely removed in the forward pass. "
+                "To preserve prompts, set `max_length` sufficiently larger than `max_new_tokens`, e.g. `max_length > max_new_tokens + 512`.",
+            )
 
diff --git a/trl/trainer/online_dpo_trainer.py b/trl/trainer/online_dpo_trainer.py
index 40bad5a69..581cd9fed 100644
--- a/trl/trainer/online_dpo_trainer.py
+++ b/trl/trainer/online_dpo_trainer.py
@@ -57,8 +57,13 @@ from ..data_utils import apply_chat_template, is_conversational, maybe_apply_chat_template
 from ..extras.profiling import profiling_context
 from ..extras.vllm_client import VLLMClient
 from ..import_utils import is_vllm_available
-from ..models import create_reference_model, prepare_peft_model
-from ..models.utils import unwrap_model_for_generation
+from ..models import (
+    create_reference_model,
+    prepare_deepspeed,
+    prepare_fsdp,
+    prepare_peft_model,
+    unwrap_model_for_generation,
+)
 from .base_trainer import BaseTrainer
 from .judges import BasePairwiseJudge
 from .online_dpo_config import OnlineDPOConfig
@@ -69,7 +74,6 @@ from .utils import (
     empty_cache,
     ensure_master_addr_port,
     pad,
-    prepare_deepspeed,
     truncate_right,
 )
 
@@ -588,24 +592,20 @@ class OnlineDPOTrainer(BaseTrainer):
         generation_kwargs = {k: v for k, v in generation_kwargs.items() if v is not None}
         self.generation_config = GenerationConfig(**generation_kwargs)
 
-        if self.is_deepspeed_enabled:
-            if self.ref_model is not None:
-                self.ref_model = prepare_deepspeed(
-                    self.ref_model, args.per_device_train_batch_size, args.fp16, args.bf16
-                )
-            # Prepare reward function models for DeepSpeed
-            if self.reward_funcs is not None:
-                for i, reward_func in enumerate(self.reward_funcs):
-                    if isinstance(reward_func, PreTrainedModel):
+        if self.ref_model is not None:
+            if self.is_deepspeed_enabled:
+                self.ref_model = prepare_deepspeed(self.ref_model, self.accelerator)
+            elif self.is_fsdp_enabled:
+                self.ref_model = prepare_fsdp(self.ref_model, self.accelerator)
+            else:
+                self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
+        if self.reward_funcs is not None:
+            for i, reward_func in enumerate(self.reward_funcs):
+                if isinstance(reward_func, PreTrainedModel):
+                    if self.is_deepspeed_enabled:
                         self.reward_funcs[i] = prepare_deepspeed(reward_func, self.accelerator)
-        else:
-            if self.ref_model is not None:
-                self.ref_model = self.ref_model.to(self.accelerator.device)
-            # Prepare reward function models for FSDP/regular training
-            if self.reward_funcs is not None:
-                for i, reward_func in enumerate(self.reward_funcs):
-                    if isinstance(reward_func, PreTrainedModel):
-                        # Set device placement to True to make `prepare_model` move `reward_func` to device when using fsdp
+                    else:
+                        # Set device placement to True to make `prepare_model` move `reward_func` to device when using FSDP
                         self.reward_funcs[i] = self.accelerator.prepare_model(
                             reward_func, evaluation_mode=True, device_placement=True
                         )
@@ -833,8 +833,10 @@ class OnlineDPOTrainer(BaseTrainer):
     def _generate_vllm_colocate(self, prompts, images=None):
         """Generate completions using vLLM colocate mode"""
 
-        # Update model weights if needed
-        self._move_model_to_vllm()
+        # Update model weights if needed, i.e. only after gradient accumulation completes
+        if self.state.global_step != self._last_loaded_step:
+            self._move_model_to_vllm()
+            self._last_loaded_step = self.state.global_step
 
         # Apply chat template if conversational
         if is_conversational({"prompt": prompts[0]}):
@@ -1234,10 +1236,12 @@ class OnlineDPOTrainer(BaseTrainer):
         # Get the logprobs of the completions from the model
         output = model(prompt_completion_ids, **model_kwargs)
 
-        # There is 1 offset, because the model predict the next token
+        # There is a 1-token offset because the model predicts the next token
         prompt_len = prompt_ids.size(1)
         start_idx = prompt_len - 1 if prompt_len > 0 else 0
-        logits = output.logits[:, start_idx:-1]
+        # Only slice off the last logit when we have a prompt, otherwise we need all logits
+        end_idx = -1 if prompt_len > 0 else None
+        logits = output.logits[:, start_idx:end_idx]
 
         # Take the completion tokens logprob
         logprobs = torch.take_along_dim(logits.log_softmax(dim=-1), completion_ids.unsqueeze(-1), dim=2).squeeze(-1)
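
A note on the last hunk: the snippet below is not part of the patch. It is a minimal sketch that replays the slicing arithmetic on dummy tensors to show why `end_idx` must be `None` when the prompt is empty. The variable names mirror the trainer code, but the toy sizes and the random stand-in for `output.logits` are illustrative assumptions.

# Minimal sketch, independent of the trainer: the logit at position i predicts
# token i + 1, so with a non-empty prompt the C completion tokens are predicted
# by positions prompt_len - 1 .. prompt_len + C - 2.
import torch

batch_size, prompt_len, completion_len, vocab_size = 2, 3, 4, 11  # toy sizes, illustrative only
prompt_ids = torch.randint(vocab_size, (batch_size, prompt_len))
completion_ids = torch.randint(vocab_size, (batch_size, completion_len))
prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1)

# Stand-in for `output.logits`: one logit vector per input position.
logits_full = torch.randn(batch_size, prompt_completion_ids.size(1), vocab_size)

start_idx = prompt_len - 1 if prompt_len > 0 else 0
# With a prompt, the final position only predicts a token beyond the sequence and
# can be dropped; with an empty prompt, dropping it would leave C - 1 logits for
# C completion tokens and break the take_along_dim below.
end_idx = -1 if prompt_len > 0 else None
logits = logits_full[:, start_idx:end_idx]

logprobs = torch.take_along_dim(
    logits.log_softmax(dim=-1), completion_ids.unsqueeze(-1), dim=2
).squeeze(-1)
assert logprobs.shape == (batch_size, completion_len)

The sketch only checks the index and shape bookkeeping: with an empty prompt, the old `[:, start_idx:-1]` slice yielded one logit too few for the completion tokens, which is exactly the shape mismatch the new `end_idx` avoids.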