[Online-DPO] fix the completion_len == max_new_tokens crash (#4193)

Co-authored-by: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Author: Kashif Rasul
Date: 2025-10-10 17:21:01 +02:00
Committed by: GitHub
Parent: 86d1963cc1
Commit: b997a31981

2 changed files with 35 additions and 24 deletions

trl/trainer/online_dpo_config.py

@@ -412,3 +412,10 @@ class OnlineDPOConfig(TrainingArguments):
         if hasattr(self.beta, "__len__") and len(self.beta) == 1:
             self.beta = self.beta[0]
+        if self.max_new_tokens >= self.max_length:
+            warnings.warn(
+                f"The configuration has `max_new_tokens` ({self.max_new_tokens}) >= `max_length` ({self.max_length}). "
+                "This will cause prompts to be truncated or completely removed in the forward pass. "
+                "To preserve prompts, ensure e.g. `max_length > max_new_tokens + 512`. ",
+            )
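
For context, a minimal sketch of how the new check fires (illustrative only: it assumes a local `trl` install with this change applied; `output_dir` and the chosen token counts are placeholders):

```python
import warnings

from trl import OnlineDPOConfig

# max_new_tokens >= max_length leaves no room for the prompt after truncation,
# so the config now warns at construction time.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    OnlineDPOConfig(output_dir="tmp-online-dpo", max_new_tokens=512, max_length=512)

assert any("max_new_tokens" in str(w.message) for w in caught)
```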

trl/trainer/online_dpo_trainer.py

@@ -57,8 +57,13 @@ from ..data_utils import apply_chat_template, is_conversational, maybe_apply_cha
 from ..extras.profiling import profiling_context
 from ..extras.vllm_client import VLLMClient
 from ..import_utils import is_vllm_available
-from ..models import create_reference_model, prepare_peft_model
-from ..models.utils import unwrap_model_for_generation
+from ..models import (
+    create_reference_model,
+    prepare_deepspeed,
+    prepare_fsdp,
+    prepare_peft_model,
+    unwrap_model_for_generation,
+)
 from .base_trainer import BaseTrainer
 from .judges import BasePairwiseJudge
 from .online_dpo_config import OnlineDPOConfig
@@ -69,7 +74,6 @@ from .utils import (
     empty_cache,
     ensure_master_addr_port,
     pad,
-    prepare_deepspeed,
     truncate_right,
 )
@@ -588,24 +592,20 @@ class OnlineDPOTrainer(BaseTrainer):
         generation_kwargs = {k: v for k, v in generation_kwargs.items() if v is not None}
         self.generation_config = GenerationConfig(**generation_kwargs)
 
-        if self.is_deepspeed_enabled:
-            if self.ref_model is not None:
-                self.ref_model = prepare_deepspeed(
-                    self.ref_model, args.per_device_train_batch_size, args.fp16, args.bf16
-                )
-            # Prepare reward function models for DeepSpeed
-            if self.reward_funcs is not None:
-                for i, reward_func in enumerate(self.reward_funcs):
-                    if isinstance(reward_func, PreTrainedModel):
-                        self.reward_funcs[i] = prepare_deepspeed(
-                            reward_func, args.per_device_train_batch_size, args.fp16, args.bf16
-                        )
-        else:
-            if self.ref_model is not None:
-                self.ref_model = self.ref_model.to(self.accelerator.device)
-            # Prepare reward function models for FSDP/regular training
-            if self.reward_funcs is not None:
-                for i, reward_func in enumerate(self.reward_funcs):
-                    if isinstance(reward_func, PreTrainedModel):
-                        # Set device placement to True to make `prepare_model` move `reward_func` to device when using fsdp
+        if self.ref_model is not None:
+            if self.is_deepspeed_enabled:
+                self.ref_model = prepare_deepspeed(self.ref_model, self.accelerator)
+            elif self.is_fsdp_enabled:
+                self.ref_model = prepare_fsdp(self.ref_model, self.accelerator)
+            else:
+                self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
+
+        if self.reward_funcs is not None:
+            for i, reward_func in enumerate(self.reward_funcs):
+                if isinstance(reward_func, PreTrainedModel):
+                    if self.is_deepspeed_enabled:
+                        self.reward_funcs[i] = prepare_deepspeed(reward_func, self.accelerator)
+                    else:
+                        # set device placement to True to make `prepare_model` move `reward_func` to device when using fsdp
                         self.reward_funcs[i] = self.accelerator.prepare_model(
                             reward_func, evaluation_mode=True, device_placement=True
                         )
@@ -833,8 +833,10 @@
     def _generate_vllm_colocate(self, prompts, images=None):
         """Generate completions using vLLM colocate mode"""
-        # Update model weights if needed
-        self._move_model_to_vllm()
+        # Update model weights if needed - only after gradient accumulation completes
+        if self.state.global_step != self._last_loaded_step:
+            self._move_model_to_vllm()
+            self._last_loaded_step = self.state.global_step
 
         # Apply chat template if conversational
         if is_conversational({"prompt": prompts[0]}):
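
The colocate path previously re-pushed weights to vLLM on every generation call; with gradient accumulation that means several redundant syncs per optimizer step. A generic sketch of the step-gated pattern used above (illustrative only; `WeightSyncer` and `push_weights_to_engine` are stand-ins, not TRL APIs):

```python
# "Sync once per optimizer step": skip the push while global_step is unchanged.
class WeightSyncer:
    def __init__(self):
        self._last_loaded_step = -1  # no weights pushed yet

    def maybe_sync(self, global_step: int, push_weights_to_engine) -> bool:
        # Under gradient accumulation, generation can be requested several times
        # while global_step stays the same; only the first request triggers a push.
        if global_step != self._last_loaded_step:
            push_weights_to_engine()
            self._last_loaded_step = global_step
            return True
        return False


syncer = WeightSyncer()
assert syncer.maybe_sync(0, lambda: None) is True   # first call at step 0: sync
assert syncer.maybe_sync(0, lambda: None) is False  # same step (accumulation): skip
assert syncer.maybe_sync(1, lambda: None) is True   # new optimizer step: sync again
```
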
@@ -1234,10 +1236,12 @@
         # Get the logprobs of the completions from the model
         output = model(prompt_completion_ids, **model_kwargs)
 
-        # There is 1 offset, because the model predict the next token
+        # There is 1 offset, because the model predicts the next token
         prompt_len = prompt_ids.size(1)
         start_idx = prompt_len - 1 if prompt_len > 0 else 0
-        logits = output.logits[:, start_idx:-1]
+        # Only slice off the last logit when we have a prompt, otherwise we need all logits
+        end_idx = -1 if prompt_len > 0 else None
+        logits = output.logits[:, start_idx:end_idx]
 
         # Take the completion tokens logprob
         logprobs = torch.take_along_dim(logits.log_softmax(dim=-1), completion_ids.unsqueeze(-1), dim=2).squeeze(-1)
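
This last hunk is the fix for the crash in the title: when the completion reaches `max_new_tokens` and `max_length <= max_new_tokens`, the prompt can be truncated away entirely, so `prompt_len == 0` and the old `start_idx:-1` slice returns one logit fewer than there are completion tokens, making `torch.take_along_dim` fail on mismatched shapes. A toy shape check of old vs. new slicing (illustrative only; random tensors, not TRL test code):

```python
import torch

# Batch of 1, 4 completion tokens, no prompt left after truncation.
prompt_len, completion_len, vocab = 0, 4, 7
logits = torch.randn(1, prompt_len + completion_len, vocab)
completion_ids = torch.randint(vocab, (1, completion_len))

start_idx = prompt_len - 1 if prompt_len > 0 else 0

# Old slicing: always drops the last position -> 3 logits for 4 completion tokens.
assert logits[:, start_idx:-1].size(1) == completion_len - 1

# New slicing: only drop the last position when a prompt is present.
end_idx = -1 if prompt_len > 0 else None
new_logits = logits[:, start_idx:end_idx]
assert new_logits.size(1) == completion_len
logprobs = torch.take_along_dim(new_logits.log_softmax(dim=-1), completion_ids.unsqueeze(-1), dim=2).squeeze(-1)
assert logprobs.shape == (1, completion_len)
```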