Compare commits

...

1 Commit

Author SHA1 Message Date
72b59bf373 initial use_transformers_paged 2025-05-28 11:26:47 +00:00
2 changed files with 125 additions and 26 deletions

trl/trainer/gkd_config.py

@@ -45,6 +45,10 @@ class GKDConfig(SFTConfig):
seq_kd (`bool`, *optional*, defaults to `False`):
Seq_kd parameter that controls whether to perform Sequence-Level KD (can be viewed as supervised FT
on teacher-generated output).
use_transformers_paged (`bool`, *optional*, defaults to `False`):
Whether to use the `transformers` paged implementation for generation. If set to `True`, the
`transformers` paged implementation will be used for generation instead of the default padded
implementation.
"""
temperature: float = field(
@@ -95,6 +99,23 @@ class GKDConfig(SFTConfig):
"FT on teacher-generated output)."
},
)
use_transformers_paged: bool = field(
default=False,
metadata={
"help": "Whether to use the `transformers` paged implementation for generation. If set to `True`, the "
"`transformers` paged implementation will be used for generation instead of the default padded "
"implementation."
},
)
ds3_gather_for_generation: bool = field(
default=True,
metadata={
"help": "This setting applies to DeepSpeed ZeRO-3. If enabled, the policy model weights are gathered for "
"generation, improving generation speed. However, disabling this option allows training models that "
"exceed the VRAM capacity of a single GPU, albeit at the cost of slower generation. Disabling this option "
"is not compatible with vLLM generation."
},
)
def __post_init__(self):
super().__post_init__()
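For context, a minimal sketch of how the two fields added in this file would be set from user code; `output_dir` is a hypothetical placeholder and the rest of the training setup follows the usual GKD recipe:

from trl import GKDConfig

training_args = GKDConfig(
    output_dir="gkd-model",          # hypothetical output path
    use_transformers_paged=True,     # opt in to the paged generation path added here
    ds3_gather_for_generation=True,  # ZeRO-3 only: gather weights for faster generation
)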

trl/trainer/gkd_trainer.py

@@ -15,12 +15,14 @@
import os
import random
import textwrap
from contextlib import nullcontext
from typing import Any, Callable, Optional, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from datasets import Dataset
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from transformers import (
AutoModelForCausalLM,
BaseImageProcessor,
@@ -36,6 +38,7 @@ from transformers.trainer_callback import TrainerCallback
from transformers.trainer_utils import EvalPrediction
from transformers.utils import is_peft_available
from ..extras.profiling import profiling_context
from ..models import prepare_deepspeed
from ..models.utils import unwrap_model_for_generation
from .gkd_config import GKDConfig
@@ -43,9 +46,9 @@ from .sft_trainer import SFTTrainer
from .utils import (
DataCollatorForChatML,
disable_dropout_in_model,
- empty_cache,
generate_model_card,
get_comet_experiment_url,
pad,
)
@@ -120,7 +123,10 @@ class GKDTrainer(SFTTrainer):
if self.is_deepspeed_enabled:
self.teacher_model = prepare_deepspeed(teacher_model, self.accelerator)
else:
- self.teacher_model = self.accelerator.prepare_model(teacher_model, evaluation_mode=True)
# For FSDP, we need device_placement=True to ensure proper device placement
self.teacher_model = self.accelerator.prepare_model(
teacher_model, evaluation_mode=True, device_placement=True
)
self.lmbda = args.lmbda
self.beta = args.beta
@@ -145,6 +151,9 @@
):
self.generation_config.eos_token_id = self.model.generation_config.eos_token_id
# Training arguments
self.use_transformers_paged = args.use_transformers_paged
def _prepare_dataset(self, dataset, *args):
# SFTTrainer._prepare_dataset() applies the chat template and renames the messages column to text. However, we
# need to keep the messages column as it is. We use the following workaround to keep the messages column.
@@ -188,9 +197,9 @@
else:
# Compute the log of the mixture distribution
# log(a + b) = log(exp(log(a)) + exp(log(b))) -> for mixture
- beta = torch.tensor(beta, dtype=student_log_probs.dtype)
beta = torch.tensor(beta, dtype=student_log_probs.dtype, device=student_log_probs.device)
mixture_log_probs = torch.logsumexp(
- torch.stack([student_log_probs + torch.log(1 - beta), teacher_log_probs + torch.log(beta)]),
torch.stack([student_log_probs + torch.log1p(-beta), teacher_log_probs + torch.log(beta)]),
dim=0,
)
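A brief aside on why `torch.log1p(-beta)` replaces `torch.log(1 - beta)`: for beta close to 0, the subtraction `1 - beta` rounds away the small term before the log is taken, whereas `log1p` evaluates log(1 + x) directly. A quick illustration with invented values:

import torch

beta = torch.tensor(1e-8, dtype=torch.float32)
print(torch.log(1 - beta))  # tensor(0.) -- 1 - 1e-8 rounds to exactly 1.0 in fp32
print(torch.log1p(-beta))   # tensor(-1.0000e-08) -- the small term survives

The added `device=student_log_probs.device` likewise avoids a CPU/GPU device mismatch when the log-probs live on an accelerator.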
@@ -224,21 +233,35 @@
attention_mask=inputs["attention_mask"],
)
- # compute teacher output in eval mode
# compute teacher output in eval mode with proper FSDP handling
self.teacher_model.eval()
with torch.no_grad():
- outputs_teacher = self.teacher_model(
- input_ids=inputs["input_ids"],
- attention_mask=inputs["attention_mask"],
- )
# For FSDP, we need to properly handle the teacher model
if self.is_fsdp_enabled:
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
# Use FSDP.summon_full_params to ensure all parameters are available
with FSDP.summon_full_params(self.teacher_model, recurse=False):
outputs_teacher = self.teacher_model(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
)
else:
# For non-FSDP models, ensure inputs are on the teacher model's device
teacher_device = next(self.teacher_model.parameters()).device
outputs_teacher = self.teacher_model(
input_ids=inputs["input_ids"].to(teacher_device),
attention_mask=inputs["attention_mask"].to(teacher_device),
)
# slice the logits for the generated tokens using the inputs["prompts"] lengths
prompt_lengths = inputs["prompts"].shape[1]
shifted_student_logits = outputs_student.logits[:, prompt_lengths - 1 : -1, :]
shifted_teacher_logits = outputs_teacher.logits[:, prompt_lengths - 1 : -1, :]
shifted_labels = inputs["labels"][:, prompt_lengths:]
# Labels are the completion tokens we want to predict
shifted_labels = inputs["input_ids"][:, prompt_lengths:].to(shifted_student_logits.device)
- # compute loss
# compute the loss
loss = self.generalized_jsd_loss(
student_logits=shifted_student_logits,
teacher_logits=shifted_teacher_logits,
@@ -246,24 +269,79 @@
beta=self.beta,
)
- # empty cache
- empty_cache()
# Return loss
return (loss, outputs_student) if return_outputs else loss
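A side note on the logit slicing in compute_loss above: logits at position t predict token t + 1, so slicing logits at `prompt_lengths - 1 : -1` lines them up exactly with the completion tokens at `prompt_lengths:`. A toy sketch with invented shapes:

import torch

# invented toy sizes: prompt of 3 tokens, completion of 2, vocab of 11
batch, seq_len, vocab, prompt_lengths = 1, 5, 11, 3
logits = torch.randn(batch, seq_len, vocab)
input_ids = torch.randint(vocab, (batch, seq_len))

shifted_logits = logits[:, prompt_lengths - 1 : -1, :]  # positions 2..3 predict tokens 3..4
shifted_labels = input_ids[:, prompt_lengths:]          # completion tokens 3..4
assert shifted_logits.shape[1] == shifted_labels.shape[1] == seq_len - prompt_lengths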
- @staticmethod
- def generate_on_policy_outputs(model, inputs, generation_config, pad_token_id=None):
- # Generate output with respect to the prompt only
- generated_outputs = model.generate(
- input_ids=inputs["prompts"],
- attention_mask=inputs.get("prompt_attention_mask", None),
- generation_config=generation_config,
- return_dict_in_generate=True,
- )
def generate_on_policy_outputs(self, model, inputs, generation_config, pad_token_id=None):
# FSDP handling (via summon_full_params below) applies to both generation paths
# Generate output with respect to the prompt only
if self.use_transformers_paged:
# Use paged attention for generation
# Extract prompts from inputs - handle both tensor and text formats
if "prompts" in inputs:
if isinstance(inputs["prompts"], torch.Tensor):
# Convert tensor prompts to text
prompts_text = self.processing_class.batch_decode(inputs["prompts"], skip_special_tokens=True)
else:
prompts_text = inputs["prompts"]
else:
# Fallback: assume input_ids contains the prompts
prompts_text = self.processing_class.batch_decode(inputs["input_ids"], skip_special_tokens=True)
prompt_inputs = self.processing_class(text=prompts_text)
generation_config.max_batch_tokens = 512
generation_config.num_blocks = 1024
generation_config.block_size = 128
generation_config.do_sample = False # logit processing issue for now
generation_config.max_new_tokens = generation_config.max_new_tokens or 128
previous_attn = model.config._attn_implementation
if torch.cuda.is_available():
model.config._attn_implementation = "paged_attention"
else:
model.config._attn_implementation = "sdpa_paged"
model.eval()
with (
profiling_context(self, "transformers.generate_batch"),
unwrap_model_for_generation(
model, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation
) as unwrapped_model,
torch.no_grad(),
FSDP.summon_full_params(model, recurse=False) if self.is_fsdp_enabled else nullcontext(),
):
unwrapped_model.to(torch.bfloat16)
all_outputs = unwrapped_model.generate_batch(
prompt_inputs.input_ids, generation_config=generation_config
)
completion_ids = [output.generated_tokens for output in all_outputs.values()]
completion_ids = [torch.tensor(ids, device=unwrapped_model.device) for ids in completion_ids]
completion_ids = pad(completion_ids, padding_value=pad_token_id)
prompt_ids = [torch.tensor(ids, device=unwrapped_model.device) for ids in prompt_inputs.input_ids]
prompt_ids = pad(prompt_ids, padding_value=pad_token_id)
generated_tokens = torch.cat([prompt_ids, completion_ids], dim=1)
model.config._attn_implementation = previous_attn
model.train()
else:
# Regular generation path with proper FSDP handling
model.eval()
with (
unwrap_model_for_generation(
model, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation
) as unwrapped_model,
torch.no_grad(),
FSDP.summon_full_params(model, recurse=False) if self.is_fsdp_enabled else nullcontext(),
):
generated_outputs = unwrapped_model.generate(
input_ids=inputs["prompts"],
attention_mask=inputs.get("prompt_attention_mask", None),
generation_config=generation_config,
return_dict_in_generate=True,
)
generated_tokens = generated_outputs.sequences
model.train()
- # Get the generated token IDs
- generated_tokens = generated_outputs.sequences
# Calculate new attention mask
new_attention_mask = torch.ones_like(generated_tokens)
new_labels = generated_tokens.clone()