mirror of
https://github.com/huggingface/trl.git
synced 2025-10-20 18:43:52 +08:00
Compare commits
26 Commits
7e9c6e45d5
...
gkd-vllm
Author | SHA1 | Date | |
---|---|---|---|
db977f63ef | |||
0e319cba26 | |||
28163312ef | |||
053edb761d | |||
e0fb8e638e | |||
88a064f797 | |||
f1f41e3195 | |||
5229d816af | |||
a8c9f8238b | |||
98eaea96b4 | |||
d16579624e | |||
37a5fcd7ca | |||
61fdb52c9c | |||
45c39c06b7 | |||
0f50383916 | |||
9427ac3379 | |||
4835a3328d | |||
382f92257f | |||
0c86e6cde6 | |||
55ec89bebe | |||
f37f0ada91 | |||
9a969d9c93 | |||
788a4a4044 | |||
d09b36a754 | |||
f665c336f2 | |||
c35c5df4ab |
@ -21,7 +21,7 @@ This post-training method was contributed by [Kashif Rasul](https://huggingface.
|
||||
|
||||
The [`GKDTrainer`] is a wrapper around the [`SFTTrainer`] class that takes in a teacher model argument. It needs three parameters to be set via the [`GKDConfig`] namely:
|
||||
* `lmbda`: controls the student data fraction, i.e., the proportion of on-policy student-generated outputs. When `lmbda=0.0`, the loss reduces to supervised JSD where the student is trained with the token-level probabilities of the teacher. When `lmbda=1.0`, the loss reduces to on-policy JSD, where the student generates output sequences and token-specific feedback on these sequences from the teacher. For values in between [0, 1] it is random between the two based on the `lmbda` value for each batch.
|
||||
* `seq_kd`: controls whether to perform Sequence-Level KD (can be viewed as supervised FT on teacher-generated out). When `seq_kd=True` and `lmbda=0.0`, the loss reduces to supervised JSD, where the teacher generates output sequences and the student receives token-specific feedback on these sequences from the teacher.
|
||||
* `seq_kd`: controls whether to perform Sequence-Level KD which can be viewed as supervised FT on teacher-generated outputs. When `seq_kd=True` and `lmbda=0.0`, the loss reduces to supervised JSD, where the teacher generates output sequences and the student receives token-specific feedback on these sequences from the teacher.
|
||||
* `beta`: controls the interpolation in the generalized Jensen-Shannon Divergence. When `beta=0.0` the loss approximates forward KL divergence, while for `beta=1.0` the loss approximates reverse KL divergence. For values in between [0, 1] it interpolates between the two.
|
||||
|
||||
The authors find that on-policy data (high `lmbda`) performs better and the optimal `beta` varied depending on the task and evaluation method.
|
||||
@ -89,6 +89,17 @@ The dataset should be formatted as a list of "messages" where each message is a
|
||||
* `content`: the message content
|
||||
|
||||
|
||||
## Accelerated Generation with vLLM
|
||||
|
||||
GKD training supports accelerated student model generation using [vLLM](https://github.com/vllm-project/vllm), which can significantly speed up the on-policy data generation process. Two integration modes are available:
|
||||
|
||||
**Server Mode (`student_vllm_mode="server"`)**: In this mode, vLLM runs as a separate server process on dedicated GPUs, and the trainer communicates with it via HTTP. This approach provides good isolation between training and inference but requires additional GPU resources for the vLLM server.
|
||||
|
||||
**Co-locate Mode (`student_vllm_mode="colocate"`)**: In this mode, vLLM runs within the same distributed process group as the training job, sharing the same GPUs. This approach maximizes GPU utilization by allowing training and inference to take turns on the same hardware, eliminating idle GPU time and reducing the total number of GPUs required. Co-locate mode typically provides better throughput and is more resource-efficient.
|
||||
|
||||
To enable vLLM integration, set `student_use_vllm=True` in your [`GKDConfig`] and configure the appropriate mode. For co-locate mode, adjust `student_vllm_gpu_memory_utilization` (recommended: 0.3 for smaller models) and `student_vllm_tensor_parallel_size` based on your model size and available resources. The `student_vllm_sync_frequency` parameter controls how often the student model weights are synchronized to the vLLM engine (default: every step).
|
||||
|
||||
|
||||
## GKDTrainer
|
||||
|
||||
[[autodoc]] GKDTrainer
|
||||
|
@ -50,6 +50,26 @@ class GKDConfig(SFTConfig):
|
||||
seq_kd (`bool`, *optional*, defaults to `False`):
|
||||
Seq_kd parameter that controls whether to perform Sequence-Level KD (can be viewed as supervised FT on
|
||||
teacher-generated output).
|
||||
student_use_vllm (`bool`, *optional*, defaults to `False`):
|
||||
Whether to use vLLM for generating completions from the student model. Requires `vllm` to be installed.
|
||||
student_vllm_mode (`str`, *optional*, defaults to `"server"`):
|
||||
Mode for student vLLM integration. Either `"server"` (connect to a running TRL vLLM server) or
|
||||
`"colocate"` (run vLLM in the same process).
|
||||
student_vllm_server_host (`str`, *optional*, defaults to `"0.0.0.0"`):
|
||||
Host of the vLLM server for the student model (if `student_vllm_mode="server"`).
|
||||
student_vllm_server_port (`int`, *optional*, defaults to `8001`):
|
||||
Port of the vLLM server for the student model (if `student_vllm_mode="server"`).
|
||||
student_vllm_server_timeout (`float`, *optional*, defaults to `240.0`):
|
||||
Timeout for connecting to the student vLLM server (if `student_vllm_mode="server"`).
|
||||
student_vllm_gpu_memory_utilization (`float`, *optional*, defaults to `0.9`):
|
||||
GPU memory utilization for the colocated student vLLM engine (if `student_vllm_mode="colocate"`).
|
||||
It is recommended to set this to a low value if the student and teacher models share the same GPU.
|
||||
student_vllm_tensor_parallel_size (`int`, *optional*, defaults to `1`):
|
||||
Tensor parallel size for the colocated student vLLM engine (if `student_vllm_mode="colocate"`).
|
||||
student_vllm_guided_decoding_regex (`str` or `None`, *optional*, defaults to `None`):
|
||||
Regex for vLLM guided decoding for the student model.
|
||||
student_vllm_sync_frequency (`int`, *optional*, defaults to `1`):
|
||||
Frequency (in training steps) to synchronize student model weights to vLLM engine. Set to 1 to sync after every step.
|
||||
"""
|
||||
|
||||
_VALID_DICT_FIELDS = TrainingArguments._VALID_DICT_FIELDS + ["teacher_model_init_kwargs"]
|
||||
@ -103,6 +123,54 @@ class GKDConfig(SFTConfig):
|
||||
},
|
||||
)
|
||||
|
||||
# VLLM parameters for student model
|
||||
student_use_vllm: bool = field(
|
||||
default=False,
|
||||
metadata={
|
||||
"help": "Whether to use vLLM for generating completions from the student model. Requires `vllm` to be installed."
|
||||
},
|
||||
)
|
||||
student_vllm_mode: str = field(
|
||||
default="server",
|
||||
metadata={
|
||||
"help": 'Mode for student vLLM integration. Either "server" (connect to a running TRL vLLM server) or "colocate" (run vLLM in the same process).'
|
||||
},
|
||||
)
|
||||
student_vllm_server_host: str = field(
|
||||
default="0.0.0.0",
|
||||
metadata={"help": 'Host of the vLLM server for the student model (if `student_vllm_mode="server"`).'},
|
||||
)
|
||||
student_vllm_server_port: int = field(
|
||||
default=8001,
|
||||
metadata={"help": 'Port of the vLLM server for the student model (if `student_vllm_mode="server"`).'},
|
||||
)
|
||||
student_vllm_server_timeout: float = field(
|
||||
default=240.0,
|
||||
metadata={"help": 'Timeout for connecting to the student vLLM server (if `student_vllm_mode="server"`).'},
|
||||
)
|
||||
student_vllm_gpu_memory_utilization: float = field(
|
||||
default=0.9,
|
||||
metadata={
|
||||
"help": 'GPU memory utilization for the colocated student vLLM engine (if `student_vllm_mode="colocate"`). It is recommended to set this to a low value if the student and teacher models share the same GPU.'
|
||||
},
|
||||
)
|
||||
student_vllm_tensor_parallel_size: int = field(
|
||||
default=1,
|
||||
metadata={
|
||||
"help": 'Tensor parallel size for the colocated student vLLM engine (if `student_vllm_mode="colocate"`).'
|
||||
},
|
||||
)
|
||||
student_vllm_guided_decoding_regex: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "Regex for vLLM guided decoding for the student model."},
|
||||
)
|
||||
student_vllm_sync_frequency: int = field(
|
||||
default=1,
|
||||
metadata={
|
||||
"help": "Frequency (in training steps) to synchronize student model weights to vLLM engine. Set to 1 to sync after every step."
|
||||
},
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
# check lmbda and beta are in the range [0, 1]
|
||||
@ -110,3 +178,10 @@ class GKDConfig(SFTConfig):
|
||||
raise ValueError("lmbda must be in the range [0.0, 1.0].")
|
||||
if self.beta < 0.0 or self.beta > 1.0:
|
||||
raise ValueError("beta must be in the range [0.0, 1.0].")
|
||||
|
||||
# Validate that max_length is sufficient for max_new_tokens
|
||||
if self.max_length is not None and self.max_new_tokens >= self.max_length:
|
||||
raise ValueError(
|
||||
f"max_new_tokens ({self.max_new_tokens}) must be smaller than max_length ({self.max_length}) "
|
||||
f"to leave room for the prompt. Consider increasing max_length or reducing max_new_tokens."
|
||||
)
|
||||
|
@ -15,12 +15,15 @@
|
||||
import os
|
||||
import random
|
||||
import textwrap
|
||||
from contextlib import nullcontext
|
||||
from typing import Any, Callable, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from accelerate.utils import broadcast_object_list, gather_object, is_peft_model
|
||||
from datasets import Dataset
|
||||
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
|
||||
from transformers import (
|
||||
AutoModelForCausalLM,
|
||||
BaseImageProcessor,
|
||||
@ -32,10 +35,12 @@ from transformers import (
|
||||
ProcessorMixin,
|
||||
is_wandb_available,
|
||||
)
|
||||
from transformers.trainer_callback import TrainerCallback
|
||||
from transformers.trainer_callback import TrainerCallback, TrainerControl, TrainerState
|
||||
from transformers.trainer_utils import EvalPrediction
|
||||
from transformers.utils import is_peft_available
|
||||
|
||||
from ..extras.vllm_client import VLLMClient
|
||||
from ..import_utils import is_vllm_available
|
||||
from ..models import prepare_deepspeed
|
||||
from ..models.utils import unwrap_model_for_generation
|
||||
from .gkd_config import GKDConfig
|
||||
@ -55,6 +60,33 @@ if is_peft_available():
|
||||
if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
if is_vllm_available():
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.sampling_params import GuidedDecodingParams
|
||||
|
||||
|
||||
class GKDStudentVLLMSyncCallback(TrainerCallback):
|
||||
"""
|
||||
Callback to sync student model weights to vLLM after training steps.
|
||||
This ensures weight syncing happens when DeepSpeed is in a stable state.
|
||||
"""
|
||||
|
||||
def __init__(self, trainer):
|
||||
self.trainer = trainer
|
||||
|
||||
def on_step_end(self, args, state: TrainerState, control: TrainerControl, **kwargs):
|
||||
"""Sync weights after training step when DeepSpeed is stable."""
|
||||
if (
|
||||
self.trainer.student_use_vllm
|
||||
and state.global_step != self.trainer._last_student_sync_step
|
||||
and state.global_step % self.trainer.student_vllm_sync_frequency == 0
|
||||
):
|
||||
# Check if this is a step where gradients are synchronized
|
||||
# This happens at the end of gradient accumulation cycles
|
||||
if hasattr(self.trainer.accelerator, "sync_gradients") and self.trainer.accelerator.sync_gradients:
|
||||
self.trainer._move_student_model_to_vllm()
|
||||
self.trainer._last_student_sync_step = state.global_step
|
||||
|
||||
|
||||
class GKDTrainer(SFTTrainer):
|
||||
_tag_names = ["trl", "gkd"]
|
||||
@ -80,6 +112,7 @@ class GKDTrainer(SFTTrainer):
|
||||
# add remove_unused_columns=False to the dataclass args
|
||||
args.remove_unused_columns = False
|
||||
data_collator = DataCollatorForChatML(tokenizer=processing_class, max_length=args.max_length)
|
||||
self.model_name_or_path = model if isinstance(model, str) else model.config._name_or_path
|
||||
|
||||
super().__init__(
|
||||
model,
|
||||
@ -117,6 +150,9 @@ class GKDTrainer(SFTTrainer):
|
||||
if args.disable_dropout:
|
||||
disable_dropout_in_model(self.model)
|
||||
|
||||
# resize the teacher's token_embeddings to the student's vocab size:
|
||||
teacher_model.resize_token_embeddings(self.model.config.vocab_size)
|
||||
|
||||
if self.is_deepspeed_enabled:
|
||||
self.teacher_model = prepare_deepspeed(teacher_model, self.accelerator)
|
||||
else:
|
||||
@ -145,6 +181,74 @@ class GKDTrainer(SFTTrainer):
|
||||
):
|
||||
self.generation_config.eos_token_id = self.model.generation_config.eos_token_id
|
||||
|
||||
# vLLM setup for student model if enabled
|
||||
self.student_use_vllm = args.student_use_vllm
|
||||
if self.student_use_vllm:
|
||||
if not is_vllm_available():
|
||||
raise ImportError(
|
||||
"vLLM is not available and student_use_vllm is set to True. Please install vLLM with "
|
||||
"`pip install vllm` to use it."
|
||||
)
|
||||
self.student_vllm_mode = args.student_vllm_mode
|
||||
if self.student_vllm_mode == "server":
|
||||
if self.accelerator.is_main_process:
|
||||
self.student_vllm_client = VLLMClient(
|
||||
host=args.student_vllm_server_host,
|
||||
server_port=args.student_vllm_server_port,
|
||||
connection_timeout=args.student_vllm_server_timeout,
|
||||
)
|
||||
self.student_vllm_client.init_communicator()
|
||||
elif self.student_vllm_mode == "colocate":
|
||||
student_model_name_or_path = self.model_name_or_path
|
||||
|
||||
# Check tensor parallel size constraints (same as GRPO)
|
||||
if args.student_vllm_tensor_parallel_size > 1:
|
||||
# Make sure tensor_parallel_size divides world size evenly
|
||||
if not self.accelerator.num_processes % args.student_vllm_tensor_parallel_size == 0:
|
||||
raise ValueError(
|
||||
f"student_vllm_tensor_parallel_size ({args.student_vllm_tensor_parallel_size}) must divide world size "
|
||||
f"({self.accelerator.num_processes}) evenly."
|
||||
)
|
||||
|
||||
# Create subgroups of ranks for TP
|
||||
self.student_tp_group, _ = torch.distributed.new_subgroups_by_enumeration(
|
||||
[
|
||||
list(
|
||||
range(
|
||||
i * args.student_vllm_tensor_parallel_size,
|
||||
(i + 1) * args.student_vllm_tensor_parallel_size,
|
||||
)
|
||||
)
|
||||
for i in range(self.accelerator.num_processes // args.student_vllm_tensor_parallel_size)
|
||||
]
|
||||
)
|
||||
|
||||
self.student_llm = LLM(
|
||||
model=student_model_name_or_path,
|
||||
tensor_parallel_size=args.student_vllm_tensor_parallel_size,
|
||||
gpu_memory_utilization=args.student_vllm_gpu_memory_utilization,
|
||||
# Max num seqs can be a small number as we generate one by one during training
|
||||
max_num_seqs=self.args.per_device_train_batch_size * self.args.gradient_accumulation_steps,
|
||||
max_model_len=args.max_length, # Assuming max_length covers prompt + new tokens
|
||||
distributed_executor_backend="external_launcher",
|
||||
# Feed identical seed for tp groups to ensure sampling results are the same across workers
|
||||
seed=args.seed
|
||||
if args.student_vllm_tensor_parallel_size == 1
|
||||
else self.accelerator.process_index // args.student_vllm_tensor_parallel_size,
|
||||
)
|
||||
|
||||
# Synchronize all processes after vLLM initialization to prevent hanging
|
||||
self.accelerator.wait_for_everyone()
|
||||
else:
|
||||
raise ValueError(f"Unknown student_vllm_mode: {self.student_vllm_mode}")
|
||||
self.student_vllm_guided_decoding_regex = args.student_vllm_guided_decoding_regex
|
||||
self.student_vllm_sync_frequency = args.student_vllm_sync_frequency
|
||||
self._last_student_sync_step = -1
|
||||
|
||||
# Add callback to sync student model weights to vLLM after training steps
|
||||
# This ensures weight syncing happens when DeepSpeed is in a stable state
|
||||
self.add_callback(GKDStudentVLLMSyncCallback(self))
|
||||
|
||||
@staticmethod
|
||||
def generalized_jsd_loss(
|
||||
student_logits, teacher_logits, labels=None, beta=0.5, temperature=1.0, reduction="batchmean"
|
||||
@ -187,9 +291,9 @@ class GKDTrainer(SFTTrainer):
|
||||
else:
|
||||
# Compute the log of the mixture distribution
|
||||
# log(a + b) = log(exp(log(a)) + exp(log(b))) -> for mixture
|
||||
beta = torch.tensor(beta, dtype=student_log_probs.dtype)
|
||||
beta = torch.tensor(beta, dtype=student_log_probs.dtype, device=student_log_probs.device)
|
||||
mixture_log_probs = torch.logsumexp(
|
||||
torch.stack([student_log_probs + torch.log(1 - beta), teacher_log_probs + torch.log(beta)]),
|
||||
torch.stack([student_log_probs + torch.log1p(-beta), teacher_log_probs + torch.log(beta)]),
|
||||
dim=0,
|
||||
)
|
||||
|
||||
@ -274,6 +378,240 @@ class GKDTrainer(SFTTrainer):
|
||||
|
||||
return generated_tokens, new_attention_mask, new_labels
|
||||
|
||||
def _generate_on_policy_outputs_student_vllm(self, inputs, generation_config, pad_token_id=None):
|
||||
device = self.accelerator.device
|
||||
|
||||
prompts_text = self.processing_class.batch_decode(
|
||||
inputs["prompts"],
|
||||
skip_special_tokens=True,
|
||||
# clean_up_tokenization_spaces=False # Keep this commented unless specific issues arise
|
||||
)
|
||||
# Remove padding token text if it appears, as vLLM expects clean prompts
|
||||
if self.processing_class.pad_token:
|
||||
prompts_text = [p.replace(self.processing_class.pad_token, "") for p in prompts_text]
|
||||
|
||||
max_new_tokens = generation_config.max_new_tokens
|
||||
temperature = generation_config.temperature
|
||||
# vLLM uses top_k=-1 for no top_k, transformers uses 0 or None.
|
||||
top_k = generation_config.top_k if generation_config.top_k and generation_config.top_k > 0 else -1
|
||||
# top_p, repetition_penalty, min_p are not directly in generation_config, get from trainer args
|
||||
top_p = self.args.top_p if hasattr(self.args, "top_p") else 1.0
|
||||
repetition_penalty = self.args.repetition_penalty if hasattr(self.args, "repetition_penalty") else 1.0
|
||||
min_p = self.args.min_p if hasattr(self.args, "min_p") else 0.0
|
||||
|
||||
if self.student_vllm_mode == "server":
|
||||
all_prompts_text = gather_object(prompts_text)
|
||||
if self.accelerator.is_main_process:
|
||||
completion_ids = self.student_vllm_client.generate(
|
||||
prompts=all_prompts_text,
|
||||
n=1, # In GKD, we generate 1 completion per prompt from student
|
||||
repetition_penalty=repetition_penalty,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
top_k=top_k,
|
||||
min_p=min_p,
|
||||
max_tokens=max_new_tokens,
|
||||
guided_decoding_regex=self.student_vllm_guided_decoding_regex,
|
||||
)
|
||||
else:
|
||||
completion_ids = [None] * len(all_prompts_text)
|
||||
completion_ids = broadcast_object_list(completion_ids, from_process=0)
|
||||
process_slice = slice(
|
||||
self.accelerator.process_index * len(prompts_text),
|
||||
(self.accelerator.process_index + 1) * len(prompts_text),
|
||||
)
|
||||
completion_ids = completion_ids[process_slice]
|
||||
elif self.student_vllm_mode == "colocate":
|
||||
if self.student_vllm_guided_decoding_regex:
|
||||
guided_decoding = GuidedDecodingParams(
|
||||
backend="outlines", regex=self.student_vllm_guided_decoding_regex
|
||||
)
|
||||
else:
|
||||
guided_decoding = None
|
||||
sampling_params = SamplingParams(
|
||||
n=1,
|
||||
repetition_penalty=repetition_penalty,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
top_k=top_k,
|
||||
min_p=min_p,
|
||||
max_tokens=max_new_tokens,
|
||||
guided_decoding=guided_decoding,
|
||||
)
|
||||
|
||||
if hasattr(self, "student_tp_group") and self.args.student_vllm_tensor_parallel_size > 1:
|
||||
# Gather prompts from all ranks in the TP group and flatten.
|
||||
# Each rank starts with its own prompts; after gathering, all ranks see the full group set.
|
||||
orig_size = len(prompts_text)
|
||||
gathered_prompts = [None for _ in range(self.args.student_vllm_tensor_parallel_size)]
|
||||
torch.distributed.all_gather_object(gathered_prompts, prompts_text, group=self.student_tp_group)
|
||||
all_prompts_text = [p for sublist in gathered_prompts for p in sublist]
|
||||
else:
|
||||
all_prompts_text = prompts_text
|
||||
|
||||
all_outputs = self.student_llm.generate(all_prompts_text, sampling_params=sampling_params, use_tqdm=False)
|
||||
completion_ids = [output.token_ids for outputs in all_outputs for output in outputs.outputs]
|
||||
|
||||
if hasattr(self, "student_tp_group") and self.args.student_vllm_tensor_parallel_size > 1:
|
||||
# Slice completions for this rank within its TP group.
|
||||
# Each rank generates all outputs — we keep only our share.
|
||||
local_rank_in_group = torch.distributed.get_rank(group=self.student_tp_group)
|
||||
tp_slice = slice(local_rank_in_group * orig_size, (local_rank_in_group + 1) * orig_size)
|
||||
completion_ids = completion_ids[tp_slice]
|
||||
else:
|
||||
raise ValueError(f"Unknown student_vllm_mode: {self.student_vllm_mode}")
|
||||
|
||||
# We need to combine prompt and completion for new_input_ids
|
||||
# Tokenize prompts again to get prompt_ids on the correct device and format
|
||||
# Ensure add_special_tokens=False as vLLM typically handles prompts as raw text
|
||||
# Calculate max_length for prompts, ensuring it's positive
|
||||
prompt_max_length = max(1, self.args.max_length - max_new_tokens) if self.args.max_length else None
|
||||
prompt_tokenized = self.processing_class(
|
||||
prompts_text,
|
||||
return_tensors="pt",
|
||||
padding="longest",
|
||||
truncation=True if prompt_max_length else False,
|
||||
max_length=prompt_max_length,
|
||||
add_special_tokens=False,
|
||||
).to(device)
|
||||
prompt_ids = prompt_tokenized.input_ids
|
||||
|
||||
completion_ids_tensors = [torch.tensor(ids, device=device) for ids in completion_ids]
|
||||
# Manually pad/truncate completions to max_new_tokens length before using pad function
|
||||
padded_completion_ids_list = []
|
||||
for completion_tensor in completion_ids_tensors:
|
||||
if len(completion_tensor) > max_new_tokens:
|
||||
# Truncate if longer than max_new_tokens
|
||||
padded_completion_ids_list.append(completion_tensor[:max_new_tokens])
|
||||
elif len(completion_tensor) < max_new_tokens:
|
||||
# Pad if shorter than max_new_tokens
|
||||
padding_needed = max_new_tokens - len(completion_tensor)
|
||||
padded_tensor = torch.cat(
|
||||
[
|
||||
completion_tensor,
|
||||
torch.full((padding_needed,), pad_token_id, device=device, dtype=completion_tensor.dtype),
|
||||
]
|
||||
)
|
||||
padded_completion_ids_list.append(padded_tensor)
|
||||
else:
|
||||
# Already the right length
|
||||
padded_completion_ids_list.append(completion_tensor)
|
||||
|
||||
# Now all tensors are the same length, so we can stack them
|
||||
padded_completion_ids = torch.stack(padded_completion_ids_list)
|
||||
|
||||
# Ensure prompt_ids and padded_completion_ids are 2D
|
||||
if prompt_ids.ndim == 1:
|
||||
prompt_ids = prompt_ids.unsqueeze(0)
|
||||
if padded_completion_ids.ndim == 1:
|
||||
padded_completion_ids = padded_completion_ids.unsqueeze(0)
|
||||
|
||||
new_input_ids = torch.cat([prompt_ids, padded_completion_ids], dim=1)
|
||||
|
||||
new_attention_mask = torch.ones_like(new_input_ids, device=device)
|
||||
new_labels = new_input_ids.clone()
|
||||
|
||||
if pad_token_id is not None:
|
||||
new_labels[new_labels == pad_token_id] = -100
|
||||
new_attention_mask[new_input_ids == pad_token_id] = 0
|
||||
|
||||
# Mask prompt tokens in labels
|
||||
prompt_lengths = prompt_ids.shape[1]
|
||||
new_labels[:, :prompt_lengths] = -100
|
||||
|
||||
return new_input_ids, new_attention_mask, new_labels
|
||||
|
||||
def _sync_fsdp_params_to_student_vllm(self, module: nn.Module, prefix: str = "", visited=None):
|
||||
"""Memory-efficient post-order traversal of FSDP modules to extract full parameters and sync with student vLLM."""
|
||||
if visited is None:
|
||||
visited = set()
|
||||
|
||||
for child_name, child_module in module.named_children():
|
||||
child_prefix = f"{prefix}.{child_name}" if prefix else child_name
|
||||
# recurse into the child
|
||||
self._sync_fsdp_params_to_student_vllm(child_module, prefix=child_prefix, visited=visited)
|
||||
|
||||
if isinstance(module, FSDP):
|
||||
with FSDP.summon_full_params(module, recurse=False, writeback=False):
|
||||
for param_name, param in module.named_parameters():
|
||||
full_name = f"{prefix}.{param_name}" if prefix else param_name
|
||||
for extra in ("_fsdp_wrapped_module.", "_checkpoint_wrapped_module."):
|
||||
full_name = full_name.replace(extra, "")
|
||||
|
||||
if full_name in visited:
|
||||
continue # skip FSDP subtrees already traversed
|
||||
visited.add(full_name)
|
||||
|
||||
if self.student_vllm_mode == "server" and self.accelerator.is_main_process:
|
||||
self.student_vllm_client.update_named_param(full_name, param.data)
|
||||
elif self.student_vllm_mode == "colocate":
|
||||
llm_model = self.student_llm.llm_engine.model_executor.driver_worker.model_runner.model
|
||||
llm_model.load_weights([(full_name, param.data)])
|
||||
|
||||
def _move_student_model_to_vllm(self):
|
||||
"""Synchronize student model weights to vLLM engine."""
|
||||
# For DeepSpeed ZeRO-3 and FSDP, we need to gather all parameters before operations
|
||||
deepspeed_plugin = self.accelerator.state.deepspeed_plugin
|
||||
zero_stage_3 = deepspeed_plugin is not None and deepspeed_plugin.zero_stage == 3
|
||||
if zero_stage_3:
|
||||
import deepspeed
|
||||
|
||||
gather_if_zero3 = deepspeed.zero.GatheredParameters
|
||||
else:
|
||||
gather_if_zero3 = nullcontext
|
||||
|
||||
if is_peft_model(self.model):
|
||||
# With PEFT and FSDP/DeepSpeed ZeRO Stage 3, we must gather the full model at once before merging, as
|
||||
# merging adapters in a sharded manner is not supported.
|
||||
with gather_if_zero3(list(self.model.parameters())):
|
||||
self.model.merge_adapter()
|
||||
|
||||
# Update vLLM weights while parameters are gathered
|
||||
if self.is_fsdp_enabled: # note if using FSDP, gather_if_zero3 is nullcontext
|
||||
# Update vLLM weights while parameters are gathered
|
||||
# For PEFT with FSDP we need to use the memory efficient post-order traversal
|
||||
self._sync_fsdp_params_to_student_vllm(self.model)
|
||||
else:
|
||||
# DeepSpeed ZeRO-3 with PEFT
|
||||
for name, param in self.model.named_parameters():
|
||||
# When using PEFT, we need to recover the original parameter name and discard some parameters
|
||||
name = name.removeprefix("base_model.model.").replace(".base_layer", "")
|
||||
if self.model.prefix in name:
|
||||
continue
|
||||
# When module to save, remove its prefix and discard the original module
|
||||
if "original_module" in name:
|
||||
continue
|
||||
name = name.replace("modules_to_save.default.", "")
|
||||
|
||||
if self.student_vllm_mode == "server" and self.accelerator.is_main_process:
|
||||
self.student_vllm_client.update_named_param(name, param.data)
|
||||
elif self.student_vllm_mode == "colocate":
|
||||
llm_model = self.student_llm.llm_engine.model_executor.driver_worker.model_runner.model
|
||||
llm_model.load_weights([(name, param.data)])
|
||||
# Unmerge adapters while parameters are still gathered
|
||||
self.model.unmerge_adapter()
|
||||
# Parameters will automatically be repartitioned when exiting the context
|
||||
else:
|
||||
# For non-PEFT models, simply gather (if needed) and update each parameter individually.
|
||||
if self.is_fsdp_enabled:
|
||||
# use memory-efficient post-order traversal for FSDP
|
||||
self._sync_fsdp_params_to_student_vllm(self.model)
|
||||
else:
|
||||
# For DeepSpeed ZeRO-3, gather each parameter individually like GRPO trainer
|
||||
for name, param in self.model.named_parameters():
|
||||
with gather_if_zero3([param]):
|
||||
if self.student_vllm_mode == "server" and self.accelerator.is_main_process:
|
||||
self.student_vllm_client.update_named_param(name, param.data)
|
||||
elif self.student_vllm_mode == "colocate":
|
||||
llm_model = self.student_llm.llm_engine.model_executor.driver_worker.model_runner.model
|
||||
llm_model.load_weights([(name, param.data)])
|
||||
|
||||
# Reset cache on vLLM
|
||||
if self.student_vllm_mode == "server" and self.accelerator.is_main_process:
|
||||
self.student_vllm_client.reset_prefix_cache()
|
||||
elif self.student_vllm_mode == "colocate":
|
||||
self.student_llm.reset_prefix_cache()
|
||||
|
||||
def training_step(
|
||||
self, model: nn.Module, inputs: dict[str, Union[torch.Tensor, Any]], num_items_in_batch: Optional[int] = None
|
||||
) -> torch.Tensor:
|
||||
@ -284,19 +622,16 @@ class GKDTrainer(SFTTrainer):
|
||||
`self.lmbda`, it generates new responses using the student model, which are then used for training instead of
|
||||
the original inputs.
|
||||
"""
|
||||
if self.seq_kd:
|
||||
with unwrap_model_for_generation(self.teacher_model, self.accelerator) as unwrapped_model:
|
||||
new_input_ids, new_attention_mask, new_labels = self.generate_on_policy_outputs(
|
||||
unwrapped_model, inputs, self.generation_config, self.processing_class.pad_token_id
|
||||
)
|
||||
inputs["input_ids"] = new_input_ids
|
||||
inputs["attention_mask"] = new_attention_mask
|
||||
inputs["labels"] = new_labels
|
||||
if random.random() <= self.lmbda:
|
||||
with unwrap_model_for_generation(model, self.accelerator) as unwrapped_model:
|
||||
new_input_ids, new_attention_mask, new_labels = self.generate_on_policy_outputs(
|
||||
unwrapped_model, inputs, self.generation_config, self.processing_class.pad_token_id
|
||||
if self.student_use_vllm:
|
||||
new_input_ids, new_attention_mask, new_labels = self._generate_on_policy_outputs_student_vllm(
|
||||
inputs, self.generation_config, self.processing_class.pad_token_id
|
||||
)
|
||||
else:
|
||||
with unwrap_model_for_generation(model, self.accelerator) as unwrapped_model:
|
||||
new_input_ids, new_attention_mask, new_labels = self.generate_on_policy_outputs(
|
||||
unwrapped_model, inputs, self.generation_config, self.processing_class.pad_token_id
|
||||
)
|
||||
inputs["input_ids"] = new_input_ids
|
||||
inputs["attention_mask"] = new_attention_mask
|
||||
inputs["labels"] = new_labels
|
||||
|
Reference in New Issue
Block a user