Compare commits

...

26 Commits

SHA1 Message Date
db977f63ef Merge branch 'main' into gkd-vllm 2025-07-04 15:15:18 +02:00
0e319cba26 Merge branch 'main' into gkd-vllm 2025-07-01 13:16:59 +02:00
28163312ef Merge branch 'gkd-vllm' of https://github.com/huggingface/trl into gkd-vllm 2025-06-26 16:24:14 +00:00
053edb761d use callback to sync weights 2025-06-26 16:23:49 +00:00
e0fb8e638e Merge branch 'main' into gkd-vllm 2025-06-26 17:25:32 +02:00
88a064f797 revert back 2025-06-26 14:04:27 +00:00
f1f41e3195 remove unused 2025-06-26 13:25:34 +00:00
5229d816af fix deepspeed 3 issue 2025-06-26 13:24:46 +00:00
a8c9f8238b Merge branch 'main' into gkd-vllm 2025-06-23 12:56:44 +02:00
98eaea96b4 set the teacher model's embedding size to that of student 2025-06-09 06:44:15 +00:00
d16579624e use is_peft_model 2025-06-05 09:25:43 +00:00
37a5fcd7ca do not fail silently 2025-06-05 09:15:47 +00:00
61fdb52c9c Merge branch 'main' into gkd-vllm 2025-06-05 11:05:52 +02:00
45c39c06b7 fix doc 2025-06-04 10:39:32 +00:00
0f50383916 fix collocation based sampling 2025-06-04 10:33:24 +00:00
9427ac3379 add doc about vllm 2025-06-03 13:03:44 +00:00
4835a3328d sync after student_vllm_sync_frequency 2025-06-03 08:42:11 +00:00
382f92257f add back _move_student_model_to_vllm 2025-06-02 14:41:26 +00:00
0c86e6cde6 Merge branch 'main' into gkd-vllm 2025-06-02 15:28:11 +02:00
55ec89bebe remove vllm for teacher 2025-05-28 14:43:52 +00:00
f37f0ada91 need the model name_or_path for vllm 2025-05-26 15:51:10 +00:00
9a969d9c93 helper to generate from model 2025-05-26 15:05:38 +00:00
788a4a4044 Merge branch 'main' into gkd-vllm 2025-05-26 16:36:53 +02:00
d09b36a754 Merge branch 'main' into gkd-vllm 2025-05-26 10:24:16 +02:00
f665c336f2 update imports 2025-05-21 18:43:08 +00:00
c35c5df4ab initial 2025-05-21 13:21:01 +00:00
3 changed files with 436 additions and 15 deletions

View File: docs/source/gkd_trainer.md

@@ -21,7 +21,7 @@ This post-training method was contributed by [Kashif Rasul](https://huggingface.
The [`GKDTrainer`] is a wrapper around the [`SFTTrainer`] class that takes in a teacher model argument. It requires three parameters to be set via the [`GKDConfig`], namely:
* `lmbda`: controls the student data fraction, i.e., the proportion of on-policy student-generated outputs. When `lmbda=0.0`, the loss reduces to supervised JSD where the student is trained with the token-level probabilities of the teacher. When `lmbda=1.0`, the loss reduces to on-policy JSD, where the student generates output sequences and receives token-specific feedback on these sequences from the teacher. For values in between [0, 1], the choice is made at random for each batch: student-generated outputs are used with probability `lmbda`.
* `seq_kd`: controls whether to perform Sequence-Level KD, which can be viewed as supervised FT on teacher-generated outputs. When `seq_kd=True` and `lmbda=0.0`, the loss reduces to supervised JSD, where the teacher generates output sequences and the student receives token-specific feedback on these sequences from the teacher.
* `beta`: controls the interpolation in the generalized Jensen-Shannon Divergence. When `beta=0.0` the loss approximates forward KL divergence, while for `beta=1.0` the loss approximates reverse KL divergence. For values in between [0, 1], it interpolates between the two.
The authors find that on-policy data (high `lmbda`) performs better, and that the optimal `beta` varies depending on the task and evaluation method.
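To make the role of `beta` concrete, here is a condensed sketch of the interpolated mixture behind [`GKDTrainer.generalized_jsd_loss`] (illustrative only and assuming `0 < beta < 1`; the full method in the diff below also handles label masking, temperature scaling, and reduction modes):

```python
import torch
import torch.nn.functional as F

def interpolated_jsd(student_logits, teacher_logits, beta=0.5):
    # Token-level log-probabilities of both models.
    student_lp = F.log_softmax(student_logits, dim=-1)
    teacher_lp = F.log_softmax(teacher_logits, dim=-1)
    # Log of the mixture M = (1 - beta) * P_student + beta * P_teacher, computed in
    # log space; log1p(-beta) is the numerically stable form of log(1 - beta).
    beta_t = torch.tensor(beta, dtype=student_lp.dtype, device=student_lp.device)
    mixture_lp = torch.logsumexp(
        torch.stack([student_lp + torch.log1p(-beta_t), teacher_lp + torch.log(beta_t)]),
        dim=0,
    )
    # F.kl_div(input, target, log_target=True) computes KL(target || input).
    kl_teacher = F.kl_div(mixture_lp, teacher_lp, reduction="batchmean", log_target=True)
    kl_student = F.kl_div(mixture_lp, student_lp, reduction="batchmean", log_target=True)
    # The weighted sum interpolates between forward KL (beta -> 0) and reverse KL (beta -> 1).
    return beta * kl_teacher + (1 - beta) * kl_student
```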
@@ -89,6 +89,17 @@ The dataset should be formatted as a list of "messages" where each message is a
* `content`: the message content
## Accelerated Generation with vLLM
GKD training supports accelerated student model generation using [vLLM](https://github.com/vllm-project/vllm), which can significantly speed up the on-policy data generation process. Two integration modes are available:
**Server Mode (`student_vllm_mode="server"`)**: In this mode, vLLM runs as a separate server process on dedicated GPUs, and the trainer communicates with it via HTTP. This approach provides good isolation between training and inference but requires additional GPU resources for the vLLM server.
**Co-locate Mode (`student_vllm_mode="colocate"`)**: In this mode, vLLM runs within the same distributed process group as the training job, sharing the same GPUs. This approach maximizes GPU utilization by allowing training and inference to take turns on the same hardware, eliminating idle GPU time and reducing the total number of GPUs required. Co-locate mode typically provides better throughput and is more resource-efficient.
To enable vLLM integration, set `student_use_vllm=True` in your [`GKDConfig`] and configure the appropriate mode. For co-locate mode, adjust `student_vllm_gpu_memory_utilization` (recommended: 0.3 for smaller models) and `student_vllm_tensor_parallel_size` based on your model size and available resources. The `student_vllm_sync_frequency` parameter controls how often the student model weights are synchronized to the vLLM engine (default: every step).
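As an illustration, a co-locate run might be configured as follows. This is a minimal sketch based on the options added in this PR; the model and values are placeholders to tune for your hardware, and server mode would instead point `student_vllm_server_host`/`student_vllm_server_port` at a running TRL vLLM server:

```python
from trl import GKDConfig

training_args = GKDConfig(
    output_dir="gkd-model",
    lmbda=0.7,                                # mostly on-policy student generations
    beta=0.5,                                 # symmetric generalized JSD
    student_use_vllm=True,                    # generate student completions with vLLM
    student_vllm_mode="colocate",             # share the training GPUs with the vLLM engine
    student_vllm_gpu_memory_utilization=0.3,  # keep low when training shares the GPU
    student_vllm_sync_frequency=1,            # push student weights to vLLM every step
)
```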
## GKDTrainer
[[autodoc]] GKDTrainer

View File: trl/trainer/gkd_config.py

@@ -50,6 +50,26 @@ class GKDConfig(SFTConfig):
seq_kd (`bool`, *optional*, defaults to `False`):
Seq_kd parameter that controls whether to perform Sequence-Level KD (can be viewed as supervised FT on
teacher-generated output).
student_use_vllm (`bool`, *optional*, defaults to `False`):
Whether to use vLLM for generating completions from the student model. Requires `vllm` to be installed.
student_vllm_mode (`str`, *optional*, defaults to `"server"`):
Mode for student vLLM integration. Either `"server"` (connect to a running TRL vLLM server) or
`"colocate"` (run vLLM in the same process).
student_vllm_server_host (`str`, *optional*, defaults to `"0.0.0.0"`):
Host of the vLLM server for the student model (if `student_vllm_mode="server"`).
student_vllm_server_port (`int`, *optional*, defaults to `8001`):
Port of the vLLM server for the student model (if `student_vllm_mode="server"`).
student_vllm_server_timeout (`float`, *optional*, defaults to `240.0`):
Timeout for connecting to the student vLLM server (if `student_vllm_mode="server"`).
student_vllm_gpu_memory_utilization (`float`, *optional*, defaults to `0.9`):
GPU memory utilization for the colocated student vLLM engine (if `student_vllm_mode="colocate"`).
It is recommended to set this to a low value if the student and teacher models share the same GPU.
student_vllm_tensor_parallel_size (`int`, *optional*, defaults to `1`):
Tensor parallel size for the colocated student vLLM engine (if `student_vllm_mode="colocate"`).
student_vllm_guided_decoding_regex (`str` or `None`, *optional*, defaults to `None`):
Regex for vLLM guided decoding for the student model.
student_vllm_sync_frequency (`int`, *optional*, defaults to `1`):
Frequency (in training steps) at which student model weights are synchronized to the vLLM engine. Set to 1 to sync after every step.
"""
_VALID_DICT_FIELDS = TrainingArguments._VALID_DICT_FIELDS + ["teacher_model_init_kwargs"]
@@ -103,6 +123,54 @@ class GKDConfig(SFTConfig):
},
)
# VLLM parameters for student model
student_use_vllm: bool = field(
default=False,
metadata={
"help": "Whether to use vLLM for generating completions from the student model. Requires `vllm` to be installed."
},
)
student_vllm_mode: str = field(
default="server",
metadata={
"help": 'Mode for student vLLM integration. Either "server" (connect to a running TRL vLLM server) or "colocate" (run vLLM in the same process).'
},
)
student_vllm_server_host: str = field(
default="0.0.0.0",
metadata={"help": 'Host of the vLLM server for the student model (if `student_vllm_mode="server"`).'},
)
student_vllm_server_port: int = field(
default=8001,
metadata={"help": 'Port of the vLLM server for the student model (if `student_vllm_mode="server"`).'},
)
student_vllm_server_timeout: float = field(
default=240.0,
metadata={"help": 'Timeout for connecting to the student vLLM server (if `student_vllm_mode="server"`).'},
)
student_vllm_gpu_memory_utilization: float = field(
default=0.9,
metadata={
"help": 'GPU memory utilization for the colocated student vLLM engine (if `student_vllm_mode="colocate"`). It is recommended to set this to a low value if the student and teacher models share the same GPU.'
},
)
student_vllm_tensor_parallel_size: int = field(
default=1,
metadata={
"help": 'Tensor parallel size for the colocated student vLLM engine (if `student_vllm_mode="colocate"`).'
},
)
student_vllm_guided_decoding_regex: Optional[str] = field(
default=None,
metadata={"help": "Regex for vLLM guided decoding for the student model."},
)
student_vllm_sync_frequency: int = field(
default=1,
metadata={
"help": "Frequency (in training steps) to synchronize student model weights to vLLM engine. Set to 1 to sync after every step."
},
)
def __post_init__(self):
super().__post_init__()
# check lmbda and beta are in the range [0, 1]
@@ -110,3 +178,10 @@ class GKDConfig(SFTConfig):
raise ValueError("lmbda must be in the range [0.0, 1.0].")
if self.beta < 0.0 or self.beta > 1.0:
raise ValueError("beta must be in the range [0.0, 1.0].")
# Validate that max_length is sufficient for max_new_tokens
if self.max_length is not None and self.max_new_tokens >= self.max_length:
raise ValueError(
f"max_new_tokens ({self.max_new_tokens}) must be smaller than max_length ({self.max_length}) "
f"to leave room for the prompt. Consider increasing max_length or reducing max_new_tokens."
)

View File: trl/trainer/gkd_trainer.py

@@ -15,12 +15,15 @@
import os
import random
import textwrap
from contextlib import nullcontext
from typing import Any, Callable, Optional, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from accelerate.utils import broadcast_object_list, gather_object, is_peft_model
from datasets import Dataset
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from transformers import (
AutoModelForCausalLM,
BaseImageProcessor,
@@ -32,10 +35,12 @@ from transformers import (
ProcessorMixin,
is_wandb_available,
)
from transformers.trainer_callback import TrainerCallback, TrainerControl, TrainerState
from transformers.trainer_utils import EvalPrediction
from transformers.utils import is_peft_available
from ..extras.vllm_client import VLLMClient
from ..import_utils import is_vllm_available
from ..models import prepare_deepspeed
from ..models.utils import unwrap_model_for_generation
from .gkd_config import GKDConfig
@@ -55,6 +60,33 @@ if is_peft_available():
if is_wandb_available():
import wandb
if is_vllm_available():
from vllm import LLM, SamplingParams
from vllm.sampling_params import GuidedDecodingParams
class GKDStudentVLLMSyncCallback(TrainerCallback):
"""
Callback to sync student model weights to vLLM after training steps.
This ensures weight syncing happens when DeepSpeed is in a stable state.
"""
def __init__(self, trainer):
self.trainer = trainer
def on_step_end(self, args, state: TrainerState, control: TrainerControl, **kwargs):
"""Sync weights after training step when DeepSpeed is stable."""
if (
self.trainer.student_use_vllm
and state.global_step != self.trainer._last_student_sync_step
and state.global_step % self.trainer.student_vllm_sync_frequency == 0
):
# Check if this is a step where gradients are synchronized
# This happens at the end of gradient accumulation cycles
if hasattr(self.trainer.accelerator, "sync_gradients") and self.trainer.accelerator.sync_gradients:
self.trainer._move_student_model_to_vllm()
self.trainer._last_student_sync_step = state.global_step
class GKDTrainer(SFTTrainer):
_tag_names = ["trl", "gkd"]
@@ -80,6 +112,7 @@ class GKDTrainer(SFTTrainer):
# add remove_unused_columns=False to the dataclass args
args.remove_unused_columns = False
data_collator = DataCollatorForChatML(tokenizer=processing_class, max_length=args.max_length)
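# Keep the model name or path around: the colocated student vLLM engine is initialized from it later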
self.model_name_or_path = model if isinstance(model, str) else model.config._name_or_path
super().__init__(
model,
@@ -117,6 +150,9 @@ class GKDTrainer(SFTTrainer):
if args.disable_dropout:
disable_dropout_in_model(self.model)
# resize the teacher's token_embeddings to the student's vocab size:
teacher_model.resize_token_embeddings(self.model.config.vocab_size)
if self.is_deepspeed_enabled:
self.teacher_model = prepare_deepspeed(teacher_model, self.accelerator)
else:
@@ -145,6 +181,74 @@ class GKDTrainer(SFTTrainer):
):
self.generation_config.eos_token_id = self.model.generation_config.eos_token_id
# vLLM setup for student model if enabled
self.student_use_vllm = args.student_use_vllm
if self.student_use_vllm:
if not is_vllm_available():
raise ImportError(
"vLLM is not available and student_use_vllm is set to True. Please install vLLM with "
"`pip install vllm` to use it."
)
self.student_vllm_mode = args.student_vllm_mode
if self.student_vllm_mode == "server":
if self.accelerator.is_main_process:
self.student_vllm_client = VLLMClient(
host=args.student_vllm_server_host,
server_port=args.student_vllm_server_port,
connection_timeout=args.student_vllm_server_timeout,
)
self.student_vllm_client.init_communicator()
elif self.student_vllm_mode == "colocate":
student_model_name_or_path = self.model_name_or_path
# Check tensor parallel size constraints (same as GRPO)
if args.student_vllm_tensor_parallel_size > 1:
# Make sure tensor_parallel_size divides world size evenly
if self.accelerator.num_processes % args.student_vllm_tensor_parallel_size != 0:
raise ValueError(
f"student_vllm_tensor_parallel_size ({args.student_vllm_tensor_parallel_size}) must divide world size "
f"({self.accelerator.num_processes}) evenly."
)
# Create subgroups of ranks for TP
self.student_tp_group, _ = torch.distributed.new_subgroups_by_enumeration(
[
list(
range(
i * args.student_vllm_tensor_parallel_size,
(i + 1) * args.student_vllm_tensor_parallel_size,
)
)
for i in range(self.accelerator.num_processes // args.student_vllm_tensor_parallel_size)
]
)
self.student_llm = LLM(
model=student_model_name_or_path,
tensor_parallel_size=args.student_vllm_tensor_parallel_size,
gpu_memory_utilization=args.student_vllm_gpu_memory_utilization,
# Max num seqs can be a small number as we generate one by one during training
max_num_seqs=self.args.per_device_train_batch_size * self.args.gradient_accumulation_steps,
max_model_len=args.max_length, # Assuming max_length covers prompt + new tokens
distributed_executor_backend="external_launcher",
# Feed identical seed for tp groups to ensure sampling results are the same across workers
seed=args.seed
if args.student_vllm_tensor_parallel_size == 1
else self.accelerator.process_index // args.student_vllm_tensor_parallel_size,
)
# Synchronize all processes after vLLM initialization to prevent hanging
self.accelerator.wait_for_everyone()
else:
raise ValueError(f"Unknown student_vllm_mode: {self.student_vllm_mode}")
self.student_vllm_guided_decoding_regex = args.student_vllm_guided_decoding_regex
self.student_vllm_sync_frequency = args.student_vllm_sync_frequency
self._last_student_sync_step = -1
# Add callback to sync student model weights to vLLM after training steps
# This ensures weight syncing happens when DeepSpeed is in a stable state
self.add_callback(GKDStudentVLLMSyncCallback(self))
@staticmethod
def generalized_jsd_loss(
student_logits, teacher_logits, labels=None, beta=0.5, temperature=1.0, reduction="batchmean"
@@ -187,9 +291,9 @@ class GKDTrainer(SFTTrainer):
else:
# Compute the log of the mixture distribution
# log(a + b) = log(exp(log(a)) + exp(log(b))) -> for mixture
beta = torch.tensor(beta, dtype=student_log_probs.dtype, device=student_log_probs.device)
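# torch.log1p(-beta) evaluates log(1 - beta) with better numerical precision for small beta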
mixture_log_probs = torch.logsumexp(
torch.stack([student_log_probs + torch.log1p(-beta), teacher_log_probs + torch.log(beta)]),
dim=0,
)
@@ -274,6 +378,240 @@ class GKDTrainer(SFTTrainer):
return generated_tokens, new_attention_mask, new_labels
def _generate_on_policy_outputs_student_vllm(self, inputs, generation_config, pad_token_id=None):
device = self.accelerator.device
prompts_text = self.processing_class.batch_decode(
inputs["prompts"],
skip_special_tokens=True,
# clean_up_tokenization_spaces=False # Keep this commented unless specific issues arise
)
# Remove padding token text if it appears, as vLLM expects clean prompts
if self.processing_class.pad_token:
prompts_text = [p.replace(self.processing_class.pad_token, "") for p in prompts_text]
max_new_tokens = generation_config.max_new_tokens
temperature = generation_config.temperature
# vLLM uses top_k=-1 for no top_k, transformers uses 0 or None.
top_k = generation_config.top_k if generation_config.top_k and generation_config.top_k > 0 else -1
# top_p, repetition_penalty, min_p are not directly in generation_config, get from trainer args
top_p = self.args.top_p if hasattr(self.args, "top_p") else 1.0
repetition_penalty = self.args.repetition_penalty if hasattr(self.args, "repetition_penalty") else 1.0
min_p = self.args.min_p if hasattr(self.args, "min_p") else 0.0
if self.student_vllm_mode == "server":
all_prompts_text = gather_object(prompts_text)
if self.accelerator.is_main_process:
completion_ids = self.student_vllm_client.generate(
prompts=all_prompts_text,
n=1, # In GKD, we generate 1 completion per prompt from student
repetition_penalty=repetition_penalty,
temperature=temperature,
top_p=top_p,
top_k=top_k,
min_p=min_p,
max_tokens=max_new_tokens,
guided_decoding_regex=self.student_vllm_guided_decoding_regex,
)
else:
completion_ids = [None] * len(all_prompts_text)
completion_ids = broadcast_object_list(completion_ids, from_process=0)
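# Every process now holds the full list of completions; keep only this process's slice below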
process_slice = slice(
self.accelerator.process_index * len(prompts_text),
(self.accelerator.process_index + 1) * len(prompts_text),
)
completion_ids = completion_ids[process_slice]
elif self.student_vllm_mode == "colocate":
if self.student_vllm_guided_decoding_regex:
guided_decoding = GuidedDecodingParams(
backend="outlines", regex=self.student_vllm_guided_decoding_regex
)
else:
guided_decoding = None
sampling_params = SamplingParams(
n=1,
repetition_penalty=repetition_penalty,
temperature=temperature,
top_p=top_p,
top_k=top_k,
min_p=min_p,
max_tokens=max_new_tokens,
guided_decoding=guided_decoding,
)
if hasattr(self, "student_tp_group") and self.args.student_vllm_tensor_parallel_size > 1:
# Gather prompts from all ranks in the TP group and flatten.
# Each rank starts with its own prompts; after gathering, all ranks see the full group set.
orig_size = len(prompts_text)
gathered_prompts = [None for _ in range(self.args.student_vllm_tensor_parallel_size)]
torch.distributed.all_gather_object(gathered_prompts, prompts_text, group=self.student_tp_group)
all_prompts_text = [p for sublist in gathered_prompts for p in sublist]
else:
all_prompts_text = prompts_text
all_outputs = self.student_llm.generate(all_prompts_text, sampling_params=sampling_params, use_tqdm=False)
completion_ids = [output.token_ids for outputs in all_outputs for output in outputs.outputs]
if hasattr(self, "student_tp_group") and self.args.student_vllm_tensor_parallel_size > 1:
# Slice completions for this rank within its TP group.
# Each rank generates all outputs — we keep only our share.
local_rank_in_group = torch.distributed.get_rank(group=self.student_tp_group)
tp_slice = slice(local_rank_in_group * orig_size, (local_rank_in_group + 1) * orig_size)
completion_ids = completion_ids[tp_slice]
else:
raise ValueError(f"Unknown student_vllm_mode: {self.student_vllm_mode}")
# We need to combine prompt and completion for new_input_ids
# Tokenize prompts again to get prompt_ids on the correct device and format
# Ensure add_special_tokens=False as vLLM typically handles prompts as raw text
# Calculate max_length for prompts, ensuring it's positive
prompt_max_length = max(1, self.args.max_length - max_new_tokens) if self.args.max_length else None
prompt_tokenized = self.processing_class(
prompts_text,
return_tensors="pt",
padding="longest",
truncation=True if prompt_max_length else False,
max_length=prompt_max_length,
add_special_tokens=False,
).to(device)
prompt_ids = prompt_tokenized.input_ids
completion_ids_tensors = [torch.tensor(ids, device=device) for ids in completion_ids]
# Manually pad/truncate completions to max_new_tokens length before using pad function
padded_completion_ids_list = []
for completion_tensor in completion_ids_tensors:
if len(completion_tensor) > max_new_tokens:
# Truncate if longer than max_new_tokens
padded_completion_ids_list.append(completion_tensor[:max_new_tokens])
elif len(completion_tensor) < max_new_tokens:
# Pad if shorter than max_new_tokens
padding_needed = max_new_tokens - len(completion_tensor)
padded_tensor = torch.cat(
[
completion_tensor,
torch.full((padding_needed,), pad_token_id, device=device, dtype=completion_tensor.dtype),
]
)
padded_completion_ids_list.append(padded_tensor)
else:
# Already the right length
padded_completion_ids_list.append(completion_tensor)
# Now all tensors are the same length, so we can stack them
padded_completion_ids = torch.stack(padded_completion_ids_list)
# Ensure prompt_ids and padded_completion_ids are 2D
if prompt_ids.ndim == 1:
prompt_ids = prompt_ids.unsqueeze(0)
if padded_completion_ids.ndim == 1:
padded_completion_ids = padded_completion_ids.unsqueeze(0)
new_input_ids = torch.cat([prompt_ids, padded_completion_ids], dim=1)
new_attention_mask = torch.ones_like(new_input_ids, device=device)
new_labels = new_input_ids.clone()
if pad_token_id is not None:
new_labels[new_labels == pad_token_id] = -100
new_attention_mask[new_input_ids == pad_token_id] = 0
# Mask prompt tokens in labels
prompt_lengths = prompt_ids.shape[1]
new_labels[:, :prompt_lengths] = -100
return new_input_ids, new_attention_mask, new_labels
def _sync_fsdp_params_to_student_vllm(self, module: nn.Module, prefix: str = "", visited=None):
"""Memory-efficient post-order traversal of FSDP modules to extract full parameters and sync with student vLLM."""
if visited is None:
visited = set()
for child_name, child_module in module.named_children():
child_prefix = f"{prefix}.{child_name}" if prefix else child_name
# recurse into the child
self._sync_fsdp_params_to_student_vllm(child_module, prefix=child_prefix, visited=visited)
if isinstance(module, FSDP):
with FSDP.summon_full_params(module, recurse=False, writeback=False):
for param_name, param in module.named_parameters():
full_name = f"{prefix}.{param_name}" if prefix else param_name
for extra in ("_fsdp_wrapped_module.", "_checkpoint_wrapped_module."):
full_name = full_name.replace(extra, "")
if full_name in visited:
continue # skip FSDP subtrees already traversed
visited.add(full_name)
if self.student_vllm_mode == "server" and self.accelerator.is_main_process:
self.student_vllm_client.update_named_param(full_name, param.data)
elif self.student_vllm_mode == "colocate":
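# Colocate mode: load the gathered full parameter directly into the in-process vLLM model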
llm_model = self.student_llm.llm_engine.model_executor.driver_worker.model_runner.model
llm_model.load_weights([(full_name, param.data)])
def _move_student_model_to_vllm(self):
"""Synchronize student model weights to vLLM engine."""
# For DeepSpeed ZeRO-3 and FSDP, we need to gather all parameters before operations
deepspeed_plugin = self.accelerator.state.deepspeed_plugin
zero_stage_3 = deepspeed_plugin is not None and deepspeed_plugin.zero_stage == 3
if zero_stage_3:
import deepspeed
gather_if_zero3 = deepspeed.zero.GatheredParameters
else:
gather_if_zero3 = nullcontext
if is_peft_model(self.model):
# With PEFT and FSDP/DeepSpeed ZeRO Stage 3, we must gather the full model at once before merging, as
# merging adapters in a sharded manner is not supported.
with gather_if_zero3(list(self.model.parameters())):
self.model.merge_adapter()
# Update vLLM weights while parameters are gathered
if self.is_fsdp_enabled: # note if using FSDP, gather_if_zero3 is nullcontext
# For PEFT with FSDP we need to use the memory efficient post-order traversal
self._sync_fsdp_params_to_student_vllm(self.model)
else:
# DeepSpeed ZeRO-3 with PEFT
for name, param in self.model.named_parameters():
# When using PEFT, we need to recover the original parameter name and discard some parameters
name = name.removeprefix("base_model.model.").replace(".base_layer", "")
if self.model.prefix in name:
continue
# When module to save, remove its prefix and discard the original module
if "original_module" in name:
continue
name = name.replace("modules_to_save.default.", "")
if self.student_vllm_mode == "server" and self.accelerator.is_main_process:
self.student_vllm_client.update_named_param(name, param.data)
elif self.student_vllm_mode == "colocate":
llm_model = self.student_llm.llm_engine.model_executor.driver_worker.model_runner.model
llm_model.load_weights([(name, param.data)])
# Unmerge adapters while parameters are still gathered
self.model.unmerge_adapter()
# Parameters will automatically be repartitioned when exiting the context
else:
# For non-PEFT models, simply gather (if needed) and update each parameter individually.
if self.is_fsdp_enabled:
# use memory-efficient post-order traversal for FSDP
self._sync_fsdp_params_to_student_vllm(self.model)
else:
# For DeepSpeed ZeRO-3, gather each parameter individually like GRPO trainer
for name, param in self.model.named_parameters():
with gather_if_zero3([param]):
if self.student_vllm_mode == "server" and self.accelerator.is_main_process:
self.student_vllm_client.update_named_param(name, param.data)
elif self.student_vllm_mode == "colocate":
llm_model = self.student_llm.llm_engine.model_executor.driver_worker.model_runner.model
llm_model.load_weights([(name, param.data)])
# Reset cache on vLLM
if self.student_vllm_mode == "server" and self.accelerator.is_main_process:
self.student_vllm_client.reset_prefix_cache()
elif self.student_vllm_mode == "colocate":
self.student_llm.reset_prefix_cache()
def training_step(
self, model: nn.Module, inputs: dict[str, Union[torch.Tensor, Any]], num_items_in_batch: Optional[int] = None
) -> torch.Tensor:
@@ -284,19 +622,16 @@ class GKDTrainer(SFTTrainer):
`self.lmbda`, it generates new responses using the student model, which are then used for training instead of
the original inputs.
"""
if self.seq_kd:
with unwrap_model_for_generation(self.teacher_model, self.accelerator) as unwrapped_model:
new_input_ids, new_attention_mask, new_labels = self.generate_on_policy_outputs(
unwrapped_model, inputs, self.generation_config, self.processing_class.pad_token_id
)
inputs["input_ids"] = new_input_ids
inputs["attention_mask"] = new_attention_mask
inputs["labels"] = new_labels
if random.random() <= self.lmbda:
if self.student_use_vllm:
new_input_ids, new_attention_mask, new_labels = self._generate_on_policy_outputs_student_vllm(
inputs, self.generation_config, self.processing_class.pad_token_id
)
else:
with unwrap_model_for_generation(model, self.accelerator) as unwrapped_model:
new_input_ids, new_attention_mask, new_labels = self.generate_on_policy_outputs(
unwrapped_model, inputs, self.generation_config, self.processing_class.pad_token_id
)
inputs["input_ids"] = new_input_ids
inputs["attention_mask"] = new_attention_mask
inputs["labels"] = new_labels