Mirror of https://github.com/huggingface/trl.git (synced 2025-10-20 10:03:51 +08:00)
[Online-DPO] fix the completion_len == max_new_tokens crash (#4193)
Co-authored-by: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
@@ -412,3 +412,10 @@ class OnlineDPOConfig(TrainingArguments):
         if hasattr(self.beta, "__len__") and len(self.beta) == 1:
             self.beta = self.beta[0]
 
+        if self.max_new_tokens >= self.max_length:
+            warnings.warn(
+                f"The configuration has `max_new_tokens` ({self.max_new_tokens}) >= `max_length` ({self.max_length}). "
+                "This will cause prompts to be truncated or completely removed in the forward pass. "
+                "To preserve prompts, ensure e.g. `max_length > max_new_tokens + 512`. ",
+            )
+
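For intuition on the new warning: the prompt and the generated completion are concatenated and cut down to `max_length`, with tokens dropped from the prompt side first, so a `max_new_tokens` that already fills `max_length` leaves no room for the prompt at all. The sketch below only illustrates that arithmetic; `surviving_prompt_tokens` is a hypothetical helper written for this note (not a TRL function) and it assumes left-truncation of the prompt as described in the warning text.

# Hypothetical helper (not TRL code): how many prompt tokens survive when the
# concatenated prompt + completion sequence is left-truncated to `max_length`.
def surviving_prompt_tokens(prompt_len: int, max_new_tokens: int, max_length: int) -> int:
    total = prompt_len + max_new_tokens
    overflow = max(0, total - max_length)
    # Left-truncation removes tokens from the prompt side first.
    return max(0, prompt_len - overflow)

# max_new_tokens >= max_length removes the prompt entirely:
assert surviving_prompt_tokens(prompt_len=256, max_new_tokens=512, max_length=512) == 0
# Keeping max_length comfortably above max_new_tokens preserves the prompt:
assert surviving_prompt_tokens(prompt_len=256, max_new_tokens=512, max_length=1024) == 256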
@@ -57,8 +57,13 @@ from ..data_utils import apply_chat_template, is_conversational, maybe_apply_chat_template
 from ..extras.profiling import profiling_context
 from ..extras.vllm_client import VLLMClient
 from ..import_utils import is_vllm_available
-from ..models import create_reference_model, prepare_peft_model
-from ..models.utils import unwrap_model_for_generation
+from ..models import (
+    create_reference_model,
+    prepare_deepspeed,
+    prepare_fsdp,
+    prepare_peft_model,
+    unwrap_model_for_generation,
+)
 from .base_trainer import BaseTrainer
 from .judges import BasePairwiseJudge
 from .online_dpo_config import OnlineDPOConfig
@@ -69,7 +74,6 @@ from .utils import (
     empty_cache,
     ensure_master_addr_port,
     pad,
-    prepare_deepspeed,
     truncate_right,
 )
 
@@ -588,24 +592,20 @@ class OnlineDPOTrainer(BaseTrainer):
         generation_kwargs = {k: v for k, v in generation_kwargs.items() if v is not None}
         self.generation_config = GenerationConfig(**generation_kwargs)
 
-        if self.is_deepspeed_enabled:
-            if self.ref_model is not None:
-                self.ref_model = prepare_deepspeed(
-                    self.ref_model, args.per_device_train_batch_size, args.fp16, args.bf16
-                )
-            # Prepare reward function models for DeepSpeed
-            if self.reward_funcs is not None:
-                for i, reward_func in enumerate(self.reward_funcs):
-                    if isinstance(reward_func, PreTrainedModel):
+        if self.ref_model is not None:
+            if self.is_deepspeed_enabled:
+                self.ref_model = prepare_deepspeed(self.ref_model, self.accelerator)
+            elif self.is_fsdp_enabled:
+                self.ref_model = prepare_fsdp(self.ref_model, self.accelerator)
+            else:
+                self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
+        if self.reward_funcs is not None:
+            for i, reward_func in enumerate(self.reward_funcs):
+                if isinstance(reward_func, PreTrainedModel):
+                    if self.is_deepspeed_enabled:
                         self.reward_funcs[i] = prepare_deepspeed(reward_func, self.accelerator)
-        else:
-            if self.ref_model is not None:
-                self.ref_model = self.ref_model.to(self.accelerator.device)
-            # Prepare reward function models for FSDP/regular training
-            if self.reward_funcs is not None:
-                for i, reward_func in enumerate(self.reward_funcs):
-                    if isinstance(reward_func, PreTrainedModel):
-                        # Set device placement to True to make `prepare_model` move `reward_func` to device when using fsdp
+                    else:
+                        # set device placement to True to make `prepare_model` move `reward_func` to device when using fsdp
                         self.reward_funcs[i] = self.accelerator.prepare_model(
                             reward_func, evaluation_mode=True, device_placement=True
                         )
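The constructor change above replaces a DeepSpeed-vs-everything-else branch with a dispatch on the backend: the frozen reference model goes through `prepare_deepspeed`, `prepare_fsdp`, or a plain `accelerator.prepare_model(..., evaluation_mode=True)` call, and reward models are handled in a single loop afterwards. A minimal standalone sketch of the reference-model part follows; `prepare_ref_model` is a hypothetical free function written for illustration, and the backend helpers are passed in as callables instead of importing TRL's `prepare_deepspeed`/`prepare_fsdp` directly.

# Hypothetical standalone sketch of the dispatch; `prepare_deepspeed_fn` and
# `prepare_fsdp_fn` stand in for TRL's `prepare_deepspeed` / `prepare_fsdp`.
def prepare_ref_model(ref_model, accelerator, *, is_deepspeed_enabled, is_fsdp_enabled,
                      prepare_deepspeed_fn, prepare_fsdp_fn):
    """Wrap the frozen reference model with the same backend used for the trained model."""
    if ref_model is None:
        return None  # no separate reference model to prepare
    if is_deepspeed_enabled:
        # DeepSpeed needs its own engine wrapper, even for a model that is never updated.
        return prepare_deepspeed_fn(ref_model, accelerator)
    if is_fsdp_enabled:
        # FSDP shards the frozen copy so it fits in memory next to the trained model.
        return prepare_fsdp_fn(ref_model, accelerator)
    # Plain DDP / single-device runs: let Accelerate place the model in evaluation-only mode.
    return accelerator.prepare_model(ref_model, evaluation_mode=True)

Reward models follow the same pattern in the diff, except that the non-DeepSpeed branch passes `device_placement=True` so `prepare_model` also moves them to the right device under FSDP.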
@@ -833,8 +833,10 @@ class OnlineDPOTrainer(BaseTrainer):
 
     def _generate_vllm_colocate(self, prompts, images=None):
         """Generate completions using vLLM colocate mode"""
-        # Update model weights if needed
-        self._move_model_to_vllm()
+        # Update model weights if needed - only after gradient accumulation completes
+        if self.state.global_step != self._last_loaded_step:
+            self._move_model_to_vllm()
+            self._last_loaded_step = self.state.global_step
 
         # Apply chat template if conversational
         if is_conversational({"prompt": prompts[0]}):
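The colocate path now pushes weights to vLLM only when the optimizer step has actually advanced, instead of on every generation call. The toy class below is a hypothetical stand-in (not TRL code) that isolates the `global_step != _last_loaded_step` gating pattern used in the diff.

# Hypothetical stand-in for the trainer state, isolating the gating pattern.
class WeightSyncGate:
    """Push weights to the inference engine at most once per optimizer step."""

    def __init__(self):
        self._last_loaded_step = -1  # no step synced yet
        self.sync_calls = 0

    def maybe_sync(self, global_step: int) -> None:
        if global_step != self._last_loaded_step:
            self.sync_calls += 1  # stands in for _move_model_to_vllm()
            self._last_loaded_step = global_step

gate = WeightSyncGate()
# Several generation calls can land on the same global_step (e.g. while gradients
# are still accumulating); only the first call per step actually syncs.
for step in [0, 0, 0, 1, 1, 2]:
    gate.maybe_sync(step)
assert gate.sync_calls == 3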
@@ -1234,10 +1236,12 @@ class OnlineDPOTrainer(BaseTrainer):
         # Get the logprobs of the completions from the model
         output = model(prompt_completion_ids, **model_kwargs)
 
-        # There is 1 offset, because the model predict the next token
+        # There is 1 offset, because the model predicts the next token
         prompt_len = prompt_ids.size(1)
         start_idx = prompt_len - 1 if prompt_len > 0 else 0
-        logits = output.logits[:, start_idx:-1]
+        # Only slice off the last logit when we have a prompt, otherwise we need all logits
+        end_idx = -1 if prompt_len > 0 else None
+        logits = output.logits[:, start_idx:end_idx]
 
         # Take the completion tokens logprob
         logprobs = torch.take_along_dim(logits.log_softmax(dim=-1), completion_ids.unsqueeze(-1), dim=2).squeeze(-1)
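This slicing change is the crash fix itself: when the prompt is truncated away completely (which can happen once `max_new_tokens` fills `max_length`), `prompt_ids.size(1)` is 0 and the old unconditional `[:, start_idx:-1]` slice dropped one logit, so the logits no longer lined up with `completion_ids` in `take_along_dim`. The snippet below reproduces the shape arithmetic with random tensors; `completion_logits` is a hypothetical helper mirroring the new lines, not a TRL function.

import torch

# Hypothetical helper mirroring the new slicing logic.
def completion_logits(logits: torch.Tensor, prompt_len: int) -> torch.Tensor:
    """Select the logits that score each completion token (1-token offset)."""
    start_idx = prompt_len - 1 if prompt_len > 0 else 0
    # Only drop the last position when a prompt exists; with no prompt every
    # position is needed, and an unconditional `:-1` loses one token.
    end_idx = -1 if prompt_len > 0 else None
    return logits[:, start_idx:end_idx]

completion_len, vocab = 4, 32

# With a prompt: sequence length is prompt_len + completion_len; the slice keeps completion_len logits.
assert completion_logits(torch.randn(1, 3 + completion_len, vocab), prompt_len=3).shape[1] == completion_len

# Prompt fully truncated away: the old `[:, 0:-1]` slice kept completion_len - 1 logits,
# which then mismatched `completion_ids` of length completion_len; the new slice keeps them all.
assert completion_logits(torch.randn(1, completion_len, vocab), prompt_len=0).shape[1] == completion_len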