Mirror of https://github.com/huggingface/trl.git, synced 2025-10-20 18:43:52 +08:00

Compare commits: f6e7c200c0...gkd-cb (1 commit)

Author | SHA1 | Date
---|---|---
 | 72b59bf373 |
trl/trainer/gkd_config.py

@@ -45,6 +45,10 @@ class GKDConfig(SFTConfig):
         seq_kd (`bool`, *optional*, defaults to `False`):
             Seq_kd parameter that controls whether to perform Sequence-Level KD (can be viewed as supervised FT
             on teacher-generated output).
+        use_transformers_paged (`bool`, *optional*, defaults to `False`):
+            Whether to use the `transformers` paged implementation for generation. If set to `True`, the
+            `transformers` paged implementation will be used for generation instead of the default padded
+            implementation.
         """
 
     temperature: float = field(
@@ -95,6 +99,23 @@ class GKDConfig(SFTConfig):
             "FT on teacher-generated output)."
         },
     )
+    use_transformers_paged: bool = field(
+        default=False,
+        metadata={
+            "help": "Whether to use the `transformers` paged implementation for generation. If set to `True`, the "
+            "`transformers` paged implementation will be used for generation instead of the default padded "
+            "implementation."
+        },
+    )
+    ds3_gather_for_generation: bool = field(
+        default=True,
+        metadata={
+            "help": "This setting applies to DeepSpeed ZeRO-3. If enabled, the policy model weights are gathered for "
+            "generation, improving generation speed. However, disabling this option allows training models that "
+            "exceed the VRAM capacity of a single GPU, albeit at the cost of slower generation. Disabling this option "
+            "is not compatible with vLLM generation."
+        },
+    )
 
     def __post_init__(self):
         super().__post_init__()
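With this patch applied, the two new fields can be set from user code. A minimal sketch (values and output path are hypothetical; the defaults are `use_transformers_paged=False` and `ds3_gather_for_generation=True`, per the fields above):

    from trl import GKDConfig

    args = GKDConfig(
        output_dir="gkd-model",           # placeholder path
        use_transformers_paged=True,      # opt in to the paged generation path
        ds3_gather_for_generation=True,   # gather ZeRO-3 shards before generating
    )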
trl/trainer/gkd_trainer.py

@@ -15,12 +15,14 @@
 import os
 import random
 import textwrap
+from contextlib import nullcontext
 from typing import Any, Callable, Optional, Union
 
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from datasets import Dataset
+from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from transformers import (
     AutoModelForCausalLM,
     BaseImageProcessor,
@@ -36,6 +38,7 @@ from transformers.trainer_callback import TrainerCallback
 from transformers.trainer_utils import EvalPrediction
 from transformers.utils import is_peft_available
 
+from ..extras.profiling import profiling_context
 from ..models import prepare_deepspeed
 from ..models.utils import unwrap_model_for_generation
 from .gkd_config import GKDConfig
@@ -43,9 +46,9 @@ from .sft_trainer import SFTTrainer
 from .utils import (
     DataCollatorForChatML,
     disable_dropout_in_model,
     empty_cache,
     generate_model_card,
     get_comet_experiment_url,
+    pad,
 )
@@ -120,7 +123,10 @@ class GKDTrainer(SFTTrainer):
         if self.is_deepspeed_enabled:
             self.teacher_model = prepare_deepspeed(teacher_model, self.accelerator)
         else:
-            self.teacher_model = self.accelerator.prepare_model(teacher_model, evaluation_mode=True)
+            # For FSDP, we need device_placement=True to ensure proper device placement
+            self.teacher_model = self.accelerator.prepare_model(
+                teacher_model, evaluation_mode=True, device_placement=True
+            )
 
         self.lmbda = args.lmbda
         self.beta = args.beta
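For reference, both keyword arguments passed to `prepare_model` above exist in Accelerate's public API. A standalone sketch of the same pattern (the checkpoint name is a placeholder):

    from accelerate import Accelerator
    from transformers import AutoModelForCausalLM

    accelerator = Accelerator()
    teacher = AutoModelForCausalLM.from_pretrained("gpt2")  # placeholder checkpoint

    # evaluation_mode=True skips the training-related wrapping for this module;
    # device_placement=True additionally moves it onto the accelerator's device,
    # which is what the FSDP path above relies on.
    teacher = accelerator.prepare_model(teacher, evaluation_mode=True, device_placement=True)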
@@ -145,6 +151,9 @@ class GKDTrainer(SFTTrainer):
         ):
             self.generation_config.eos_token_id = self.model.generation_config.eos_token_id
 
+        # Training arguments
+        self.use_transformers_paged = args.use_transformers_paged
+
     def _prepare_dataset(self, dataset, *args):
         # SFTTrainer._prepare_dataset() applies the chat template and renames the messages column to text. However,
         # we need to keep the messages column as it is. We use the following workaround to keep the messages column.
@@ -188,9 +197,9 @@ class GKDTrainer(SFTTrainer):
         else:
             # Compute the log of the mixture distribution
             # log(a + b) = log(exp(log(a)) + exp(log(b))) -> for mixture
-            beta = torch.tensor(beta, dtype=student_log_probs.dtype)
+            beta = torch.tensor(beta, dtype=student_log_probs.dtype, device=student_log_probs.device)
             mixture_log_probs = torch.logsumexp(
-                torch.stack([student_log_probs + torch.log(1 - beta), teacher_log_probs + torch.log(beta)]),
+                torch.stack([student_log_probs + torch.log1p(-beta), teacher_log_probs + torch.log(beta)]),
                 dim=0,
             )
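Two fixes land in this hunk: `beta` is now created on the same device as the log-probs (avoiding a CPU/GPU mismatch inside `torch.stack`), and `torch.log1p(-beta)` replaces `torch.log(1 - beta)`, which loses precision when `beta` is small. A quick self-contained illustration of the precision point:

    import torch

    beta = torch.tensor(1e-8, dtype=torch.float32)
    print(torch.log(1 - beta))  # 0.0, because 1 - 1e-8 rounds to exactly 1.0 in float32
    print(torch.log1p(-beta))   # ~ -1.0000e-08, the accurate result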
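The surrounding method computes the beta-interpolated Jensen-Shannon divergence used by GKD. A minimal sketch of the overall computation, assuming 0 < beta < 1 (the real method also handles the beta = 0 and beta = 1 edge cases and label masking; this is an illustration, not the exact TRL implementation):

    import torch
    import torch.nn.functional as F

    def generalized_jsd(student_logits, teacher_logits, beta=0.5):
        # m = beta * p_teacher + (1 - beta) * p_student, computed in log space
        student_lp = F.log_softmax(student_logits, dim=-1)
        teacher_lp = F.log_softmax(teacher_logits, dim=-1)
        beta = torch.tensor(beta, dtype=student_lp.dtype, device=student_lp.device)
        mixture_lp = torch.logsumexp(
            torch.stack([student_lp + torch.log1p(-beta), teacher_lp + torch.log(beta)]),
            dim=0,
        )
        # JSD_beta = beta * KL(teacher || m) + (1 - beta) * KL(student || m)
        kl_teacher = F.kl_div(mixture_lp, teacher_lp, reduction="batchmean", log_target=True)
        kl_student = F.kl_div(mixture_lp, student_lp, reduction="batchmean", log_target=True)
        return beta * kl_teacher + (1 - beta) * kl_student

    # smoke test with random logits of shape (batch, seq, vocab)
    print(generalized_jsd(torch.randn(2, 5, 11), torch.randn(2, 5, 11), beta=0.3))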
@@ -224,21 +233,35 @@
             attention_mask=inputs["attention_mask"],
         )
 
-        # compute teacher output in eval mode
+        # compute teacher output in eval mode with proper FSDP handling
         self.teacher_model.eval()
         with torch.no_grad():
-            outputs_teacher = self.teacher_model(
-                input_ids=inputs["input_ids"],
-                attention_mask=inputs["attention_mask"],
-            )
+            # For FSDP, we need to properly handle the teacher model
+            if self.is_fsdp_enabled:
+                from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+
+                # Use FSDP.summon_full_params to ensure all parameters are available
+                with FSDP.summon_full_params(self.teacher_model, recurse=False):
+                    outputs_teacher = self.teacher_model(
+                        input_ids=inputs["input_ids"],
+                        attention_mask=inputs["attention_mask"],
+                    )
+            else:
+                # For non-FSDP models, ensure inputs are on the teacher model's device
+                teacher_device = next(self.teacher_model.parameters()).device
+                outputs_teacher = self.teacher_model(
+                    input_ids=inputs["input_ids"].to(teacher_device),
+                    attention_mask=inputs["attention_mask"].to(teacher_device),
+                )
 
         # slice the logits for the generated tokens using the inputs["prompts"] lengths
         prompt_lengths = inputs["prompts"].shape[1]
         shifted_student_logits = outputs_student.logits[:, prompt_lengths - 1 : -1, :]
         shifted_teacher_logits = outputs_teacher.logits[:, prompt_lengths - 1 : -1, :]
-        shifted_labels = inputs["labels"][:, prompt_lengths:]
+        # Labels are the completion tokens we want to predict
+        shifted_labels = inputs["input_ids"][:, prompt_lengths:].to(shifted_student_logits.device)
 
-        # compute loss
+        # compute the loss
         loss = self.generalized_jsd_loss(
             student_logits=shifted_student_logits,
             teacher_logits=shifted_teacher_logits,
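The slicing above follows the usual next-token shift: `logits[:, t, :]` scores token `t + 1`, so the completion tokens `input_ids[:, prompt_lengths:]` are predicted by the logits starting at index `prompt_lengths - 1` and stopping one before the end. A tiny shape check (all sizes hypothetical):

    import torch

    batch, prompt_lengths, total_len, vocab = 2, 4, 10, 32
    logits = torch.randn(batch, total_len, vocab)
    input_ids = torch.randint(vocab, (batch, total_len))

    shifted_logits = logits[:, prompt_lengths - 1 : -1, :]  # predicts positions 4..9
    shifted_labels = input_ids[:, prompt_lengths:]          # the completion tokens
    assert shifted_logits.shape[1] == shifted_labels.shape[1] == total_len - prompt_lengths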
@@ -246,24 +269,79 @@
             beta=self.beta,
         )
 
         # empty cache
         empty_cache()
 
         # Return loss
         return (loss, outputs_student) if return_outputs else loss
 
-    @staticmethod
-    def generate_on_policy_outputs(model, inputs, generation_config, pad_token_id=None):
-        # Generate output with respect to the prompt only
-        generated_outputs = model.generate(
-            input_ids=inputs["prompts"],
-            attention_mask=inputs.get("prompt_attention_mask", None),
-            generation_config=generation_config,
-            return_dict_in_generate=True,
-        )
+    def generate_on_policy_outputs(self, model, inputs, generation_config, pad_token_id=None):
+        # FSDP is imported at module level for both generation paths
+
+        # Generate output with respect to the prompt only
+        if self.use_transformers_paged:
+            # Use paged attention for generation
+            # Extract prompts from inputs - handle both tensor and text formats
+            if "prompts" in inputs:
+                if isinstance(inputs["prompts"], torch.Tensor):
+                    # Convert tensor prompts to text
+                    prompts_text = self.processing_class.batch_decode(inputs["prompts"], skip_special_tokens=True)
+                else:
+                    prompts_text = inputs["prompts"]
+            else:
+                # Fallback: assume input_ids contains the prompts
+                prompts_text = self.processing_class.batch_decode(inputs["input_ids"], skip_special_tokens=True)
+
+            prompt_inputs = self.processing_class(text=prompts_text)
+            generation_config.max_batch_tokens = 512
+            generation_config.num_blocks = 1024
+            generation_config.block_size = 128
+            generation_config.do_sample = False  # logit processing issue for now
+            generation_config.max_new_tokens = generation_config.max_new_tokens or 128
+
+            previous_attn = model.config._attn_implementation
+            if torch.cuda.is_available():
+                model.config._attn_implementation = "paged_attention"
+            else:
+                model.config._attn_implementation = "sdpa_paged"
+
+            model.eval()
+            with (
+                profiling_context(self, "transformers.generate_batch"),
+                unwrap_model_for_generation(
+                    model, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation
+                ) as unwrapped_model,
+                torch.no_grad(),
+                FSDP.summon_full_params(model, recurse=False) if self.is_fsdp_enabled else nullcontext(),
+            ):
+                unwrapped_model.to(torch.bfloat16)
+                all_outputs = unwrapped_model.generate_batch(
+                    prompt_inputs.input_ids, generation_config=generation_config
+                )
+            completion_ids = [output.generated_tokens for output in all_outputs.values()]
+            completion_ids = [torch.tensor(ids, device=unwrapped_model.device) for ids in completion_ids]
+            completion_ids = pad(completion_ids, padding_value=pad_token_id)
+            prompt_ids = [torch.tensor(ids, device=unwrapped_model.device) for ids in prompt_inputs.input_ids]
+            prompt_ids = pad(prompt_ids, padding_value=pad_token_id)
+            generated_tokens = torch.cat([prompt_ids, completion_ids], dim=1)
+
+            model.config._attn_implementation = previous_attn
+            model.train()
+        else:
+            # Regular generation path with proper FSDP handling
+            model.eval()
+            with (
+                unwrap_model_for_generation(
+                    model, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation
+                ) as unwrapped_model,
+                torch.no_grad(),
+                FSDP.summon_full_params(model, recurse=False) if self.is_fsdp_enabled else nullcontext(),
+            ):
+                generated_outputs = unwrapped_model.generate(
+                    input_ids=inputs["prompts"],
+                    attention_mask=inputs.get("prompt_attention_mask", None),
+                    generation_config=generation_config,
+                    return_dict_in_generate=True,
+                )
+            generated_tokens = generated_outputs.sequences
+            model.train()
 
-        # Get the generated token IDs
-        generated_tokens = generated_outputs.sequences
         # Calculate new attention mask
         new_attention_mask = torch.ones_like(generated_tokens)
         new_labels = generated_tokens.clone()
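The newly imported `pad` helper is what batches the variable-length prompt and completion id lists above before `torch.cat`. A small usage sketch (import path inferred from the `from .utils import pad` added in this diff):

    import torch
    from trl.trainer.utils import pad

    seqs = [torch.tensor([5, 6, 7]), torch.tensor([8, 9])]
    print(pad(seqs, padding_value=0))
    # tensor([[5, 6, 7],
    #         [8, 9, 0]])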