Mirror of https://github.com/huggingface/trl.git (synced 2025-10-20 18:43:52 +08:00)

Compare commits: 13 commits, 7e9c6e45d5...feature/sg

Commits: 688b046fb8, 3c2edbd808, e942b65999, b68d780617, abf4af49e9, 30cf6ca9ef, e5a7fd7e99, b908d3b177, 183b3e0eaa, 0786ed66cd, b71e173100, eba6b5c0af, e4940d65a1

@@ -250,13 +250,76 @@ By default, GRPO uses `MASTER_ADDR=localhost` and `MASTER_PORT=12345` for vLLM,

For more information, see [Speeding up training with vLLM](speeding_up_training#vllm-for-fast-generation-in-online-methods).

### Speed up training with SGLang-powered generation

Generation is often the main bottleneck when training with online methods. As an alternative to vLLM, you can use [SGLang](https://github.com/sgl-project/sglang), a high-performance inference engine designed for structured generation with language models. SGLang is particularly well-suited for tasks requiring structured outputs, complex reasoning patterns, and advanced prompting techniques. To enable it, first install the package with

```shell
pip install trl[sglang]
```

We support two ways of using SGLang during training: **server mode** and **colocate mode**.

#### 🔌 Option 1: Server mode

In this mode, SGLang runs in a separate process (and using separate GPUs) and communicates with the trainer via HTTP. This is ideal if you have dedicated GPUs for inference.

1. **Start the SGLang server**:

   ```bash
   trl sglang-serve --model <model_name>
   ```

2. **Enable server mode in your training script**:

   ```python
   from trl import GRPOConfig

   training_args = GRPOConfig(
       ...,
       use_sglang=True,
       sglang_mode="server",  # default value, can be omitted
   )
   ```

<Tip warning={true}>

Make sure that the server is using different GPUs than the trainer, otherwise you may run into NCCL errors. You can specify the GPUs to use with the `CUDA_VISIBLE_DEVICES` environment variable.
</Tip>
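
For example, on a single 8-GPU node you might dedicate one GPU to the SGLang server and train on the remaining seven. This is only a sketch: the model name, the GPU split, and `train_grpo.py` are placeholders for your own setup.

```bash
# Terminal 1: SGLang server on its own GPU
CUDA_VISIBLE_DEVICES=7 trl sglang-serve --model Qwen/Qwen2.5-7B

# Terminal 2: training on the remaining GPUs
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6 accelerate launch --num_processes 7 train_grpo.py
```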

#### 🧩 Option 2: Colocate mode

In this mode, SGLang runs inside the trainer process and shares GPU memory with the training model. This avoids launching a separate server and can improve GPU utilization, but may lead to memory contention on the training GPUs.

```python
from trl import GRPOConfig

training_args = GRPOConfig(
    ...,
    use_sglang=True,
    sglang_mode="colocate",
)
```

<Tip>

Depending on the model size and the overall GPU memory requirements for training, you may need to adjust the `sglang_gpu_memory_utilization` parameter in [`GRPOConfig`] to avoid underutilization or out-of-memory errors.
</Tip>
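
As a rough starting point, lower values leave more memory for the trainer while higher values give SGLang more room for its KV cache. The number below is purely illustrative, not a recommendation:

```python
from trl import GRPOConfig

training_args = GRPOConfig(
    ...,
    use_sglang=True,
    sglang_mode="colocate",
    sglang_gpu_memory_utilization=0.3,  # illustrative; lower it if training runs out of memory
)
```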

<Tip>

By default, GRPO uses `MASTER_ADDR=localhost` and `MASTER_PORT=12345` for SGLang, but you can override these values by setting the environment variables accordingly.
</Tip>
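
For example, to avoid a port clash you could export different values before launching training (a sketch; `train_grpo.py` is a placeholder script name):

```bash
MASTER_ADDR=localhost MASTER_PORT=29500 accelerate launch train_grpo.py
```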

For more information, see [SGLang Integration](sglang_integration).

### GRPO at scale: train a 70B+ Model on multiple nodes

When training large models like **Qwen2.5-72B**, you need several key optimizations to make the training efficient and scalable across multiple GPUs and nodes. These include:

- **DeepSpeed ZeRO Stage 3**: ZeRO leverages data parallelism to distribute model states (weights, gradients, optimizer states) across multiple GPUs and CPUs, reducing memory and compute requirements on each device. Since large models cannot fit on a single GPU, using ZeRO Stage 3 is required for training such models. For more details, see [DeepSpeed Integration](deepspeed_integration).
- **Accelerate**: Accelerate is a library that simplifies distributed training across multiple GPUs and nodes. It provides a simple API to launch distributed training and handles the complexities of distributed training, such as data parallelism, gradient accumulation, and distributed data loading. For more details, see [Distributing Training](distributing_training).
- **vLLM or SGLang**: See the previous sections on how to use vLLM or SGLang to speed up generation. Both engines provide high-performance inference capabilities, with SGLang being particularly well-suited for structured generation tasks.

Below is an example SLURM script to train a 70B model with GRPO on multiple nodes. This script trains a model on 4 nodes and uses the 5th node for vLLM-powered generation.

@@ -323,6 +386,73 @@ if __name__=="__main__":
    main()
```

#### Alternative: Using SGLang for large-scale training

You can also use SGLang instead of vLLM for large-scale training. SGLang is particularly well-suited for structured generation tasks and complex reasoning patterns. Here's the same example using SGLang:

```sh
#!/bin/bash
#SBATCH --nodes=5
#SBATCH --gres=gpu:8

# Get the list of allocated nodes
NODELIST=($(scontrol show hostnames $SLURM_JOB_NODELIST))

# Assign the first 4 nodes for training and the 5th node for SGLang
TRAIN_NODES="${NODELIST[@]:0:4}"  # Nodes 0, 1, 2, 3 for training
SGLANG_NODE="${NODELIST[4]}"      # Node 4 for SGLang

# Run training on the first 4 nodes (Group 1)
srun --nodes=4 --ntasks=4 --nodelist="${NODELIST[@]:0:4}" accelerate launch \
    --config_file examples/accelerate_configs/deepspeed_zero3.yaml \
    --num_processes 32 \
    --num_machines 4 \
    --main_process_ip ${NODELIST[0]} \
    --machine_rank $SLURM_PROCID \
    --rdzv_backend c10d \
    train_grpo_sglang.py \
    --sglang_server_host $SGLANG_NODE &

# Run SGLang server on the 5th node (Group 2)
srun --nodes=1 --ntasks=1 --nodelist="${NODELIST[4]}" trl sglang-serve --model Qwen/Qwen2.5-72B --tensor_parallel_size 8 &

wait
```

```python
import argparse

from datasets import load_dataset
from trl import GRPOTrainer, GRPOConfig


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--sglang_server_host", type=str, default="", help="The SGLang server IP")
    args = parser.parse_args()

    # Example dataset from TLDR
    dataset = load_dataset("trl-lib/tldr", split="train")

    # Dummy reward function: count the number of unique characters in the completions
    def reward_num_unique_chars(completions, **kwargs):
        return [len(set(c)) for c in completions]

    training_args = GRPOConfig(
        output_dir="Qwen2.5-72B-GRPO-SGLang",
        per_device_train_batch_size=4,
        bf16=True,
        gradient_checkpointing=True,
        use_sglang=True,
        sglang_server_base_url=f"http://{args.sglang_server_host}:8001",
    )

    trainer = GRPOTrainer(model="Qwen/Qwen2.5-72B", args=training_args, reward_funcs=reward_num_unique_chars, train_dataset=dataset)
    trainer.train()


if __name__=="__main__":
    main()
```

### Using a custom reward function

The [`GRPOTrainer`] supports using custom reward functions instead of dense reward models. To ensure compatibility, your reward function must satisfy the following requirements:
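
The SGLang script above already defines one such function (`reward_num_unique_chars`). As a minimal sketch of the expected shape, a reward function receives the generated `completions` (plus any extra dataset columns via `**kwargs`) and returns one float per completion; the scoring rule below is purely illustrative:

```python
def reward_short_completions(completions, **kwargs):
    """Toy reward that favors shorter completions (illustrative only)."""
    return [-float(len(completion)) for completion in completions]
```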

@@ -551,6 +681,7 @@ Compatibility with all VLMs is not guaranteed. If you believe a model should be

Use [grpo_vlm.py](https://github.com/huggingface/trl/blob/main/examples/scripts/grpo_vlm.py) to fine-tune a VLM. Example command for training on [`lmms-lab/multimodal-open-r1-8k-verified`](https://huggingface.co/datasets/lmms-lab/multimodal-open-r1-8k-verified):

```bash
# Using vLLM
accelerate launch \
    --config_file=examples/accelerate_configs/deepspeed_zero3.yaml \
    examples/scripts/grpo_vlm.py \
@@ -566,6 +697,23 @@ accelerate launch \
    --use_peft \
    --lora_target_modules "q_proj", "v_proj" \
    --log_completions

# Using SGLang (alternative for structured generation)
accelerate launch \
    --config_file=examples/accelerate_configs/deepspeed_zero3.yaml \
    examples/scripts/grpo_vlm.py \
    --model_name_or_path Qwen/Qwen2.5-VL-3B-Instruct \
    --output_dir grpo-Qwen2.5-VL-3B-Instruct-SGLang \
    --learning_rate 1e-5 \
    --gradient_checkpointing \
    --torch_dtype bfloat16 \
    --max_prompt_length 2048 \
    --max_completion_length 1024 \
    --use_sglang \
    --sglang_mode colocate \
    --use_peft \
    --lora_target_modules "q_proj", "v_proj" \
    --log_completions
```

### Configuration Tips

@@ -577,7 +725,8 @@ VLM training may fail if image tokens are truncated. We highly recommend to disa
- Use LoRA on vision-language projection layers (see the sketch after this list)
- Enable 4-bit quantization to reduce memory usage
- VLMs are memory-intensive — start with smaller batch sizes
- Most models are compatible with both vLLM and SGLang (`server` and `colocate` modes)
- SGLang is particularly well-suited for structured generation tasks with VLMs
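
Putting several of these tips together, a colocate-mode VLM run might be configured roughly as follows. This is a sketch, not a tested recipe: the target modules are placeholders and should match your model's projection layers, and the SGLang options come from this change.

```python
from peft import LoraConfig
from trl import GRPOConfig

# Small batches, LoRA on projection layers, SGLang colocated with training.
training_args = GRPOConfig(
    output_dir="grpo-vlm-sglang",
    per_device_train_batch_size=1,  # VLMs are memory-intensive; start small
    gradient_checkpointing=True,
    bf16=True,
    use_sglang=True,
    sglang_mode="colocate",
)
lora_config = LoraConfig(target_modules=["q_proj", "v_proj"])  # placeholder projection layers

# Pass both to the trainer, e.g. GRPOTrainer(model=..., args=training_args, peft_config=lora_config, ...)
```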

### Dataset Format

@@ -58,12 +58,12 @@ from typing import Any

import requests
import torch
import wandb
from datasets import load_dataset
from peft import LoraConfig
from qwen_vl_utils import process_vision_info
from transformers import AutoModelForVision2Seq, AutoProcessor, BitsAndBytesConfig, Qwen2VLProcessor

import wandb

from trl import ModelConfig, ScriptArguments, SFTConfig, SFTTrainer, TrlParser, get_kbit_device_map

setup.cfg
@@ -69,6 +69,14 @@ vllm =
    requests; python_version < "3.13"
    uvicorn; python_version < "3.13"

sglang =
    # SGLang package for accelerated inference
    sglang>=0.4.10
    fastapi
    pydantic
    requests
    uvicorn

vlm =
    Pillow
    torchvision
@@ -82,8 +90,10 @@ dev =
    %(peft)s
    %(quantization)s
    %(scikit)s
    %(sglang)s
    %(test)s
    %(vlm)s
    %(vllm)s

[options.entry_points]
console_scripts =

@@ -39,7 +39,7 @@ from trl.trainer.grpo_trainer import (
    unsplit_pixel_values_by_grid,
)

from .testing_utils import require_sglang, require_vllm


if is_peft_available():
@@ -1274,6 +1274,198 @@ class GRPOTrainerTester(unittest.TestCase):
|
||||
new_param = trainer.model.get_parameter(n)
|
||||
self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")
|
||||
|
||||
@require_sglang
|
||||
def test_training_sglang_colocate_mode(self):
|
||||
"""Test that training works with SGLang colocate mode for generation."""
|
||||
dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
training_args = GRPOConfig(
|
||||
output_dir=tmp_dir,
|
||||
learning_rate=0.1, # increase the learning rate to speed up the test
|
||||
per_device_train_batch_size=2, # reduce the batch size to reduce memory usage
|
||||
num_generations=2, # reduce the number of generations to reduce memory usage
|
||||
max_completion_length=8, # reduce the completion length to reduce memory usage
|
||||
max_prompt_length=64, # reduce prompt length to save memory
|
||||
report_to="none",
|
||||
use_sglang=True,
|
||||
sglang_mode="colocate",
|
||||
sglang_gpu_memory_utilization=0.1, # Use minimal GPU memory
|
||||
use_sglang_bucketed_updates=True, # Enable bucketed updates
|
||||
sglang_update_weight_buffer_size=64 * 1024**2, # 64MB buffer
|
||||
sglang_pause_generation_during_update=True,
|
||||
)
|
||||
trainer = GRPOTrainer(
|
||||
model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
|
||||
reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
|
||||
args=training_args,
|
||||
train_dataset=dataset,
|
||||
)
|
||||
previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
|
||||
trainer.train()
|
||||
self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
|
||||
|
||||
# Verify SGLang integration is working
|
||||
self.assertTrue(trainer.use_sglang, "SGLang should be enabled")
|
||||
self.assertEqual(trainer.sglang_mode, "colocate", "SGLang should be in colocate mode")
|
||||
self.assertIsNotNone(trainer.sglang_engine, "SGLang engine should be initialized")
|
||||
|
||||
# Verify slime-style bucketed weight updater is enabled
|
||||
self.assertTrue(trainer.args.use_sglang_bucketed_updates, "Bucketed updates should be enabled")
|
||||
if hasattr(trainer, "sglang_weight_updater"):
|
||||
self.assertIsNotNone(trainer.sglang_weight_updater, "SGLang weight updater should be initialized")
|
||||
|
||||
# Check that the params have changed
|
||||
for n, param in previous_trainable_params.items():
|
||||
new_param = trainer.model.get_parameter(n)
|
||||
self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")
|
||||
|
||||
@require_sglang
|
||||
def test_training_sglang_colocate_with_peft(self):
|
||||
"""Test that training works with SGLang colocate mode and PEFT."""
|
||||
model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
|
||||
base_param_names = [f"base_model.model.{n}" for n, _ in model.named_parameters()]
|
||||
dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
training_args = GRPOConfig(
|
||||
output_dir=tmp_dir,
|
||||
learning_rate=0.1, # increase the learning rate to speed up the test
|
||||
per_device_train_batch_size=2, # reduce the batch size to reduce memory usage
|
||||
num_generations=2, # reduce the number of generations to reduce memory usage
|
||||
max_completion_length=8, # reduce the completion length to reduce memory usage
|
||||
max_prompt_length=64, # reduce prompt length to save memory
|
||||
report_to="none",
|
||||
use_sglang=True,
|
||||
sglang_mode="colocate",
|
||||
sglang_gpu_memory_utilization=0.1, # Use minimal GPU memory
|
||||
use_sglang_bucketed_updates=False, # Disable bucketed updates for now
|
||||
)
|
||||
lora_config = LoraConfig(
|
||||
target_modules="all-linear",
|
||||
modules_to_save=["lm_head"], # Simpler config for testing
|
||||
)
|
||||
trainer = GRPOTrainer(
|
||||
model=model,
|
||||
reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
|
||||
args=training_args,
|
||||
train_dataset=dataset,
|
||||
peft_config=lora_config,
|
||||
)
|
||||
previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
|
||||
trainer.train()
|
||||
self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
|
||||
|
||||
# Verify SGLang with PEFT integration
|
||||
self.assertTrue(trainer.use_sglang, "SGLang should be enabled")
|
||||
self.assertEqual(trainer.sglang_mode, "colocate", "SGLang should be in colocate mode")
|
||||
|
||||
# Check that the peft params have changed and the base model params have not changed
|
||||
for n, param in previous_trainable_params.items():
|
||||
new_param = trainer.model.get_parameter(n)
|
||||
if n in base_param_names: # We expect the base model params to be the same
|
||||
self.assertTrue(torch.allclose(param, new_param), f"Parameter {n} has changed.")
|
||||
elif "base_layer" not in n and "original_module" not in n:
|
||||
# We expect the peft params to be different (except for the base layer)
|
||||
self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed.")
|
||||
|
||||
@require_sglang
|
||||
def test_training_sglang_bucketed_weight_updates(self):
|
||||
"""Test SGLang with sophisticated slime-style bucketed weight updates."""
|
||||
dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
training_args = GRPOConfig(
|
||||
output_dir=tmp_dir,
|
||||
learning_rate=0.1, # increase the learning rate to speed up the test
|
||||
per_device_train_batch_size=2, # reduce the batch size to reduce memory usage
|
||||
num_generations=2, # reduce the number of generations to reduce memory usage
|
||||
max_completion_length=8, # reduce the completion length to reduce memory usage
|
||||
max_prompt_length=64, # reduce prompt length to save memory
|
||||
report_to="none",
|
||||
use_sglang=True,
|
||||
sglang_mode="colocate",
|
||||
sglang_gpu_memory_utilization=0.05, # Use minimal GPU memory for testing
|
||||
# Enable ALL sophisticated slime-style features
|
||||
use_sglang_bucketed_updates=True, # Core bucketed updates
|
||||
sglang_update_weight_buffer_size=32 * 1024**2, # 32MB buffer for testing
|
||||
sglang_pause_generation_during_update=True, # Pause/resume during updates
|
||||
)
|
||||
trainer = GRPOTrainer(
|
||||
model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
|
||||
reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
|
||||
args=training_args,
|
||||
train_dataset=dataset,
|
||||
)
|
||||
|
||||
# Verify sophisticated weight updater is properly initialized
|
||||
self.assertTrue(trainer.use_sglang, "SGLang should be enabled")
|
||||
self.assertEqual(trainer.sglang_mode, "colocate", "SGLang should be in colocate mode")
|
||||
self.assertIsNotNone(trainer.sglang_engine, "SGLang engine should be initialized")
|
||||
|
||||
# Verify slime-style bucketed weight updater
|
||||
self.assertTrue(trainer.args.use_sglang_bucketed_updates, "Bucketed updates should be enabled")
|
||||
self.assertIsNotNone(trainer.sglang_weight_updater, "SGLang weight updater should be initialized")
|
||||
self.assertEqual(trainer.args.sglang_update_weight_buffer_size, 32 * 1024**2, "Buffer size should be 32MB")
|
||||
self.assertTrue(trainer.args.sglang_pause_generation_during_update, "Generation pause should be enabled")
|
||||
|
||||
# Test memory info functionality
|
||||
memory_info = trainer.sglang_weight_updater.get_memory_info()
|
||||
self.assertIn("total_GB", memory_info, "Memory info should contain total GB")
|
||||
self.assertIn("free_GB", memory_info, "Memory info should contain free GB")
|
||||
self.assertIn("used_GB", memory_info, "Memory info should contain used GB")
|
||||
|
||||
# Test parameter info extraction
|
||||
param_infos = trainer.sglang_weight_updater.get_param_infos(trainer.model)
|
||||
self.assertGreater(len(param_infos), 0, "Should extract parameter information")
|
||||
|
||||
# Verify parameter info structure
|
||||
for param_info in param_infos[:3]: # Check first few parameters
|
||||
self.assertIsInstance(param_info.name, str, "Parameter name should be string")
|
||||
self.assertIsInstance(param_info.size, int, "Parameter size should be integer")
|
||||
self.assertGreater(param_info.size, 0, "Parameter size should be positive")
|
||||
|
||||
# Test bucketing functionality
|
||||
param_buckets = trainer.sglang_weight_updater.get_param_info_buckets(param_infos)
|
||||
self.assertGreater(len(param_buckets), 0, "Should create parameter buckets")
|
||||
self.assertIsInstance(param_buckets, list, "Buckets should be a list")
|
||||
|
||||
# Verify buckets respect memory constraints
|
||||
for bucket in param_buckets:
|
||||
bucket_size = sum(p.size for p in bucket)
|
||||
# Allow single oversized parameters to be in their own bucket
|
||||
if len(bucket) == 1:
|
||||
# Single parameter bucket can exceed buffer size (oversized parameter)
|
||||
continue
|
||||
else:
|
||||
# Multi-parameter buckets must respect buffer size
|
||||
self.assertLessEqual(
|
||||
bucket_size,
|
||||
trainer.args.sglang_update_weight_buffer_size,
|
||||
f"Multi-parameter bucket size ({bucket_size}) should not exceed buffer size ({trainer.args.sglang_update_weight_buffer_size})",
|
||||
)
|
||||
|
||||
# Test sophisticated weight update
|
||||
previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
|
||||
|
||||
# This should trigger the sophisticated slime-style bucketed weight update
|
||||
trainer._move_model_to_sglang()
|
||||
|
||||
# Verify parameters changed (indicating successful weight update)
|
||||
params_changed = 0
|
||||
for n, param in previous_trainable_params.items():
|
||||
new_param = trainer.model.get_parameter(n)
|
||||
if not torch.equal(param, new_param):
|
||||
params_changed += 1
|
||||
|
||||
# Note: Parameters might not change in a simple weight sync, but the operation should complete successfully
|
||||
# The key test is that the sophisticated weight update completes without errors
|
||||
|
||||
# Test that we can do multiple weight updates (stress test)
|
||||
for i in range(3):
|
||||
try:
|
||||
trainer._move_model_to_sglang()
|
||||
# Each update should complete successfully
|
||||
except Exception as e:
|
||||
self.fail(f"Weight update {i + 1} failed: {e}")
|
||||
|
||||
def test_training_no_scale_rewards(self):
|
||||
dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
|
||||
|
||||
|

@@ -26,6 +26,7 @@ from trl.import_utils import (
    is_joblib_available,
    is_llm_blender_available,
    is_mergekit_available,
    is_sglang_available,
    is_vllm_available,
)

@@ -81,6 +82,13 @@ def require_sklearn(test_case):
    return unittest.skipUnless(is_sklearn_available() and is_joblib_available(), "test requires sklearn")(test_case)


def require_sglang(test_case):
    """
    Decorator marking a test that requires sglang. Skips the test if sglang is not available.
    """
    return unittest.skipUnless(is_sglang_available(), "test requires sglang")(test_case)


def require_vllm(test_case):
    """
    Decorator marking a test that requires vllm. Skips the test if vllm is not available.

trl/cli.py
@@ -24,6 +24,8 @@ from .scripts.env import print_env
from .scripts.grpo import make_parser as make_grpo_parser
from .scripts.kto import make_parser as make_kto_parser
from .scripts.sft import make_parser as make_sft_parser
from .scripts.sglang_serve import main as sglang_serve_main
from .scripts.sglang_serve import make_parser as make_sglang_serve_parser
from .scripts.utils import TrlParser
from .scripts.vllm_serve import main as vllm_serve_main
from .scripts.vllm_serve import make_parser as make_vllm_serve_parser
@@ -41,6 +43,7 @@ def main():
    make_grpo_parser(subparsers)
    make_kto_parser(subparsers)
    make_sft_parser(subparsers)
    make_sglang_serve_parser(subparsers)
    make_vllm_serve_parser(subparsers)

    # Parse the arguments; the remaining ones (`launch_args`) are passed to the 'accelerate launch' subparser.
@@ -132,6 +135,20 @@ def main():

        vllm_serve_main(script_args)

    elif args.command == "sglang-serve":
        (script_args,) = parser.parse_args_and_config()

        # Same warning as for `vllm-serve`: this CLI entry point may misbehave with this parallelism setup.
        if script_args.tensor_parallel_size == 1 and script_args.data_parallel_size > 1:
            warnings.warn(
                "Detected configuration: tensor_parallel_size=1 and data_parallel_size>1. This setup may "
                "cause issues when using the `trl sglang-serve` CLI entry point. If you encounter issues, "
                "please run the server using the module path instead: `python -m trl.scripts.sglang_serve`",
                RuntimeWarning,
            )

        sglang_serve_main(script_args)


if __name__ == "__main__":
    main()
467
trl/extras/sglang_client.py
Normal file
467
trl/extras/sglang_client.py
Normal file
@ -0,0 +1,467 @@
|
||||
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import atexit
|
||||
import base64
|
||||
import logging
|
||||
import socket
|
||||
import time
|
||||
from io import BytesIO
|
||||
from typing import Optional, Union
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from ..import_utils import is_requests_available, is_sglang_available
|
||||
|
||||
|
||||
if is_requests_available():
|
||||
import requests
|
||||
from requests import ConnectionError
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SGLangClient:
|
||||
"""
|
||||
A client class to interact with an SGLang server.
|
||||
|
||||
This class provides methods to generate completions, initialize and manage weight update groups, and update model
|
||||
weights in a distributed setting. Before using it, start the SGLang server with `trl sglang-serve`.
|
||||
|
||||
Args:
|
||||
base_url (`str` or `None`, *optional*, defaults to `None`):
|
||||
Base URL for the SGLang server (e.g., `"http://localhost:8001"`). If provided, `host` and `server_port` are
|
||||
ignored.
|
||||
host (`str`, *optional*, defaults to `"0.0.0.0"`):
|
||||
IP address of the SGLang server. Ignored if `base_url` is provided.
|
||||
server_port (`int`, *optional*, defaults to `8001`):
|
||||
Port number of the SGLang server. Ignored if `base_url` is provided.
|
||||
group_port (`int`, *optional*, defaults to `51217`):
|
||||
Port number for the weight update group.
|
||||
connection_timeout (`float`, *optional*, defaults to `0.0`):
|
||||
Total timeout duration in seconds to wait for the server to be up.
|
||||
|
||||
Examples:
|
||||
Run the SGLang server with the model `Qwen/Qwen2.5-7B`:
|
||||
|
||||
```
|
||||
$ trl sglang-serve --model Qwen/Qwen2.5-7B
|
||||
...
|
||||
INFO: Application startup complete.
|
||||
INFO: Uvicorn running on http://0.0.0.0:8001 (Press CTRL+C to quit)
|
||||
```
|
||||
|
||||
Use the client to generate completions and update model weights:
|
||||
|
||||
```python
|
||||
>>> from trl.extras.sglang_client import SGLangClient
|
||||
|
||||
>>> client = SGLangClient()
|
||||
>>> client.generate(["Hello, AI!", "Tell me a joke"])
|
||||
[[2980, 498, 1492, 752, 448, 264, 13027, 8645, 30, 358, 2776, 4460, 311, 3270, 264, 2025],
|
||||
[911, 7988, 1251, 382, 3838, 653, 498, 1618, 4325, 879, 2581, 20027, 264, 21428, 30, 362]]
|
||||
|
||||
>>> from transformers import AutoModelForCausalLM
|
||||
|
||||
>>> model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-7B", device_map="cuda")
|
||||
>>> client.init_communicator(device="cuda")
|
||||
>>> client.update_model_params(model)
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
base_url: Optional[str] = None,
|
||||
host: str = "0.0.0.0",
|
||||
server_port: int = 8001,
|
||||
group_port: int = 51217,
|
||||
connection_timeout: float = 0.0,
|
||||
):
|
||||
if not is_requests_available():
|
||||
raise ImportError("requests is not installed. Please install it with `pip install requests`.")
|
||||
if not is_sglang_available():
|
||||
raise ImportError("SGLang is not installed. Please install it with `pip install sglang`.")
|
||||
|
||||
self.session = requests.Session()
|
||||
|
||||
if base_url is not None:
|
||||
# Parse the base_url to extract host and port
|
||||
parsed_url = urlparse(base_url)
|
||||
self.host = socket.gethostbyname(parsed_url.hostname)
|
||||
scheme = parsed_url.scheme or "http"
|
||||
self.base_url = f"{scheme}://{parsed_url.netloc}{parsed_url.path}"
|
||||
else:
|
||||
self.host = host
|
||||
self.server_port = server_port
|
||||
self.base_url = f"http://{self.host}:{self.server_port}"
|
||||
self.group_port = group_port
|
||||
self.check_server(connection_timeout) # check server and fail after timeout
|
||||
|
||||
# Initialize communicator-related attributes
|
||||
self.pynccl_comm = None
|
||||
self.rank = None
|
||||
self.world_size = None
|
||||
|
||||
def check_server(self, total_timeout: float = 0.0, retry_interval: float = 2.0):
|
||||
"""
|
||||
Check server availability with retries on failure.
|
||||
|
||||
Args:
|
||||
retry_interval (`float`, *optional*, defaults to `2.0`):
|
||||
Interval in seconds between retries.
|
||||
total_timeout (`float`, *optional*, defaults to `0.0`):
|
||||
Total timeout duration in seconds.
|
||||
"""
|
||||
url = f"{self.base_url}/health/"
|
||||
start_time = time.time()
|
||||
|
||||
while True:
|
||||
try:
|
||||
response = requests.get(url)
|
||||
except requests.exceptions.RequestException as exc:
|
||||
# Check if the total timeout duration has passed
|
||||
elapsed_time = time.time() - start_time
|
||||
if elapsed_time >= total_timeout:
|
||||
raise ConnectionError(
|
||||
f"The SGLang server can't be reached at {self.base_url} after {total_timeout} seconds. Make "
|
||||
"sure the server is running by running `trl sglang-serve`."
|
||||
) from exc
|
||||
else:
|
||||
if response.status_code == 200:
|
||||
if "X-Forwarded-For" in response.headers:
|
||||
self.host = response.headers["X-Forwarded-For"]
|
||||
logger.info("Server is up!")
|
||||
return None
|
||||
|
||||
# Retry logic: wait before trying again
|
||||
logger.info(f"Server is not up yet. Retrying in {retry_interval} seconds...")
|
||||
time.sleep(retry_interval)
|
||||
|
||||
def generate(
|
||||
self,
|
||||
prompts: list[str],
|
||||
images: Optional[list] = None,
|
||||
temperature: float = 1.0,
|
||||
top_p: float = 1.0,
|
||||
top_k: int = -1,
|
||||
max_tokens: int = 16,
|
||||
sampling_params: Optional[dict] = None,
|
||||
) -> list[list[int]]:
|
||||
"""
|
||||
Generates model completions for the provided prompts.
|
||||
|
||||
Args:
|
||||
prompts (`list[str]`):
|
||||
List of text prompts for which the model will generate completions.
|
||||
images (`list[PIL.Image]` or `None`, *optional*, defaults to `None`):
|
||||
List of PIL Images to send along with the prompts.
|
||||
temperature (`float`, *optional*, defaults to `1.0`):
|
||||
Temperature parameter for sampling.
|
||||
top_p (`float`, *optional*, defaults to `1.0`):
|
||||
Top-p sampling parameter.
|
||||
top_k (`int`, *optional*, defaults to `-1`):
|
||||
Top-k sampling parameter.
|
||||
max_tokens (`int`, *optional*, defaults to `16`):
|
||||
Maximum number of tokens to generate.
|
||||
sampling_params (`dict` or `None`, *optional*, defaults to `None`):
|
||||
Additional sampling parameters for SGLang.
|
||||
|
||||
Returns:
|
||||
`list[list[int]]`:
|
||||
List of lists of token IDs representing the model-generated completions.
|
||||
"""
|
||||
url = f"{self.base_url}/generate/"
|
||||
|
||||
def pil_to_base64(image):
|
||||
buffer = BytesIO()
|
||||
image.save(buffer, format="PNG")
|
||||
img_bytes = buffer.getvalue()
|
||||
return base64.b64encode(img_bytes).decode("utf-8")
|
||||
|
||||
# Convert PIL images to base64 strings
|
||||
images = [pil_to_base64(img) for img in images] if images else None
|
||||
|
||||
# Prepare sampling parameters
|
||||
params = sampling_params or {}
|
||||
# Update with provided parameters, ensuring correct key names
|
||||
params.update(
|
||||
{
|
||||
"temperature": temperature,
|
||||
"top_p": top_p,
|
||||
"top_k": top_k,
|
||||
"max_new_tokens": params.get(
|
||||
"max_new_tokens", max_tokens
|
||||
), # Support both max_tokens and max_new_tokens
|
||||
}
|
||||
)
|
||||
# Remove max_tokens if max_new_tokens is present to avoid conflicts
|
||||
if "max_tokens" in params and "max_new_tokens" in params:
|
||||
params.pop("max_tokens")
|
||||
|
||||
response = self.session.post(
|
||||
url,
|
||||
json={
|
||||
"prompts": prompts,
|
||||
"images": images,
|
||||
"sampling_params": params,
|
||||
},
|
||||
)
|
||||
if response.status_code == 200:
|
||||
return response.json()["completion_ids"]
|
||||
else:
|
||||
raise Exception(f"Request failed: {response.status_code}, {response.text}")
|
||||
|
||||
def init_communicator(self, device: Union[torch.device, str, int] = 0):
|
||||
"""
|
||||
Initializes the weight update group in a distributed setup for model synchronization.
|
||||
|
||||
Args:
|
||||
device (`torch.device`, `str`, or `int`, *optional*, defaults to `0`):
|
||||
Device of trainer main process.
|
||||
"""
|
||||
# Get the world size from the server
|
||||
url = f"{self.base_url}/get_world_size/"
|
||||
response = requests.get(url)
|
||||
if response.status_code == 200:
|
||||
sglang_world_size = response.json()["world_size"]
|
||||
else:
|
||||
raise Exception(f"Request failed: {response.status_code}, {response.text}")
|
||||
|
||||
world_size = sglang_world_size + 1 # add the client to the world
|
||||
self.rank = sglang_world_size # the client's rank is the last process
|
||||
self.world_size = world_size
|
||||
|
||||
# Initialize weight update group
|
||||
url = f"{self.base_url}/init_communicator/"
|
||||
client_device_uuid = str(torch.cuda.get_device_properties(device).uuid)
|
||||
|
||||
response = self.session.post(
|
||||
url,
|
||||
json={
|
||||
"host": "0.0.0.0",
|
||||
"port": self.group_port,
|
||||
"world_size": world_size,
|
||||
"client_device_uuid": client_device_uuid,
|
||||
},
|
||||
)
|
||||
if response.status_code != 200:
|
||||
raise Exception(f"Request failed: {response.status_code}, {response.text}")
|
||||
|
||||
# Brief delay to allow server initialization
|
||||
time.sleep(0.1)
|
||||
|
||||
# Set up the communication group for weight broadcasting
|
||||
import torch.distributed as dist
|
||||
|
||||
if not dist.is_initialized():
|
||||
dist.init_process_group(
|
||||
backend="nccl",
|
||||
init_method=f"tcp://{self.host}:{self.group_port}",
|
||||
rank=self.rank,
|
||||
world_size=world_size,
|
||||
)
|
||||
|
||||
# When the client object is deleted, close the weight update group
|
||||
atexit.register(self.close_communicator)
|
||||
|
||||
def update_named_param(self, name: str, weights: torch.Tensor):
|
||||
"""
|
||||
Updates a specific named parameter in the model and broadcasts it to other processes.
|
||||
Uses SGLang's native weight update mechanism for efficiency.
|
||||
|
||||
Args:
|
||||
name (`str`):
|
||||
Name of the layer whose weights are being updated.
|
||||
weights (`torch.Tensor`):
|
||||
Tensor containing the updated weights.
|
||||
"""
|
||||
dtype_str = str(weights.dtype)
|
||||
shape = list(weights.shape)
|
||||
|
||||
# Use SGLang's update_weights_from_distributed endpoint
|
||||
url = f"{self.base_url}/update_weights/"
|
||||
response = self.session.post(
|
||||
url,
|
||||
json={
|
||||
"names": [name],
|
||||
"dtypes": [dtype_str],
|
||||
"shapes": [shape],
|
||||
"group_name": "weight_sync",
|
||||
"flush_cache": True,
|
||||
},
|
||||
)
|
||||
if response.status_code != 200:
|
||||
raise Exception(f"Request failed: {response.status_code}, {response.text}")
|
||||
|
||||
# Broadcast the weights to the other processes using NCCL
|
||||
import torch.distributed as dist
|
||||
|
||||
dist.broadcast(weights, src=self.rank)
|
||||
dist.barrier()
|
||||
|
||||
def update_model_params(self, model: nn.Module):
|
||||
"""
|
||||
Updates all parameters of the given model.
|
||||
|
||||
Args:
|
||||
model (`nn.Module`):
|
||||
Model whose parameters are to be updated.
|
||||
"""
|
||||
# Batch all parameter updates
|
||||
names = []
|
||||
dtypes = []
|
||||
shapes = []
|
||||
weights_list = []
|
||||
|
||||
for name, param in model.named_parameters():
|
||||
names.append(name)
|
||||
dtypes.append(str(param.data.dtype))
|
||||
shapes.append(list(param.data.shape))
|
||||
weights_list.append(param.data)
|
||||
|
||||
# Send metadata to server using SGLang's batch update API
|
||||
url = f"{self.base_url}/update_weights/"
|
||||
response = self.session.post(
|
||||
url,
|
||||
json={
|
||||
"names": names,
|
||||
"dtypes": dtypes,
|
||||
"shapes": shapes,
|
||||
"group_name": "weight_sync",
|
||||
"flush_cache": True,
|
||||
},
|
||||
)
|
||||
if response.status_code != 200:
|
||||
raise Exception(f"Request failed: {response.status_code}, {response.text}")
|
||||
|
||||
# Broadcast all weights
|
||||
import torch.distributed as dist
|
||||
|
||||
for weight in weights_list:
|
||||
dist.broadcast(weight, src=self.rank)
|
||||
dist.barrier()
|
||||
|
||||
def update_weights_bucketed(
|
||||
self,
|
||||
names: list[str],
|
||||
dtypes: list[str],
|
||||
shapes: list[list[int]],
|
||||
group_name: str = "weight_sync",
|
||||
flush_cache: bool = False,
|
||||
):
|
||||
"""
|
||||
Updates model weights using bucketed batch approach (slime-style).
|
||||
|
||||
Args:
|
||||
names (`list[str]`):
|
||||
List of parameter names to update.
|
||||
dtypes (`list[str]`):
|
||||
List of parameter data types.
|
||||
shapes (`list[list[int]]`):
|
||||
List of parameter shapes.
|
||||
group_name (`str`, *optional*, defaults to `"weight_sync"`):
|
||||
Name of the distributed group for weight synchronization.
|
||||
flush_cache (`bool`, *optional*, defaults to `False`):
|
||||
Whether to flush the cache after this bucket update.
|
||||
"""
|
||||
# Send metadata to server using SGLang's batch update API
|
||||
url = f"{self.base_url}/update_weights/"
|
||||
response = self.session.post(
|
||||
url,
|
||||
json={
|
||||
"names": names,
|
||||
"dtypes": dtypes,
|
||||
"shapes": shapes,
|
||||
"group_name": group_name,
|
||||
"flush_cache": flush_cache,
|
||||
},
|
||||
)
|
||||
if response.status_code != 200:
|
||||
raise Exception(f"SGLang bucketed weight update failed: {response.status_code}, {response.text}")
|
||||
|
||||
def get_memory_info(self):
|
||||
"""
|
||||
Get memory information from the SGLang server.
|
||||
|
||||
Returns:
|
||||
dict: Memory information from the server.
|
||||
"""
|
||||
url = f"{self.base_url}/get_memory_info/"
|
||||
try:
|
||||
response = self.session.get(url)
|
||||
if response.status_code == 200:
|
||||
return response.json()
|
||||
else:
|
||||
logger.warning(f"Failed to get memory info: {response.status_code}")
|
||||
return {"error": "Unable to get memory info"}
|
||||
except Exception as e:
|
||||
logger.warning(f"Exception getting memory info: {e}")
|
||||
return {"error": str(e)}
|
||||
|
||||
def pause_generation(self):
|
||||
"""Pause generation on the SGLang server."""
|
||||
url = f"{self.base_url}/pause_generation/"
|
||||
response = self.session.post(url)
|
||||
if response.status_code != 200:
|
||||
raise Exception(f"Failed to pause generation: {response.status_code}, {response.text}")
|
||||
|
||||
def continue_generation(self):
|
||||
"""Continue generation on the SGLang server."""
|
||||
url = f"{self.base_url}/continue_generation/"
|
||||
response = self.session.post(url)
|
||||
if response.status_code != 200:
|
||||
raise Exception(f"Failed to continue generation: {response.status_code}, {response.text}")
|
||||
|
||||
def flush_cache(self):
|
||||
"""
|
||||
Flush the cache for the model.
|
||||
"""
|
||||
url = f"{self.base_url}/flush_cache/"
|
||||
response = self.session.post(url)
|
||||
if response.status_code != 200:
|
||||
raise Exception(f"Request failed: {response.status_code}, {response.text}")
|
||||
|
||||
def close_communicator(self):
|
||||
"""
|
||||
Closes the weight update group and cleans up the communication group.
|
||||
"""
|
||||
url = f"{self.base_url}/close_communicator/"
|
||||
|
||||
try:
|
||||
response = self.session.post(url)
|
||||
except ConnectionError:
|
||||
# The server might be already down
|
||||
pass
|
||||
else:
|
||||
if response.status_code != 200:
|
||||
logger.warning(f"Failed to close communicator: {response.status_code}, {response.text}")
|
||||
|
||||
|
||||
# Example usage
|
||||
if __name__ == "__main__":
|
||||
client = SGLangClient()
|
||||
client.init_communicator(device="cuda")
|
||||
|
||||
# Generate completions
|
||||
responses = client.generate(["Hello, AI!", "Tell me a joke"], max_tokens=32)
|
||||
# Example output would show responses here
|
||||
|
||||
# Update model weights
|
||||
from transformers import AutoModelForCausalLM
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-7B").to("cuda")
|
||||
client.update_model_params(model)
|

trl/extras/sglang_engine_adapter.py (new file)
@@ -0,0 +1,411 @@
|
||||
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
SGLang Engine Adapter - Improved implementation based on slime patterns.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import multiprocessing
|
||||
import os
|
||||
import time
|
||||
from typing import Optional
|
||||
|
||||
import requests
|
||||
from urllib3.exceptions import NewConnectionError
|
||||
|
||||
from ..import_utils import is_sglang_available
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
if is_sglang_available():
|
||||
from sglang.srt.entrypoints.http_server import launch_server
|
||||
from sglang.srt.server_args import ServerArgs
|
||||
from sglang.srt.utils import kill_process_tree
|
||||
|
||||
|
||||
def get_base_gpu_id(args, rank):
|
||||
"""
|
||||
Calculate base GPU ID for SGLang engine based on slime's logic.
|
||||
|
||||
Args:
|
||||
args: Configuration arguments
|
||||
rank: Rank of the current engine
|
||||
|
||||
Returns:
|
||||
int: Base GPU ID to use
|
||||
"""
|
||||
num_gpus = min(getattr(args, "sglang_num_gpus_per_node", 8), getattr(args, "sglang_num_gpus_per_engine", 1))
|
||||
|
||||
if getattr(args, "colocate", True):
|
||||
start_index = (rank * num_gpus) % getattr(args, "sglang_num_gpus_per_node", 8)
|
||||
else:
|
||||
num_actor_gpus = getattr(args, "actor_num_gpus_per_node", 0) * getattr(args, "actor_num_nodes", 1)
|
||||
start_index = (num_actor_gpus + rank * num_gpus) % getattr(args, "sglang_num_gpus_per_node", 8)
|
||||
|
||||
return start_index
|
||||
|
||||
|
||||
def launch_server_process(server_args: ServerArgs) -> multiprocessing.Process:
|
||||
"""
|
||||
Launch SGLang server in a separate process with proper health checking.
|
||||
Based on slime's implementation.
|
||||
"""
|
||||
p = multiprocessing.Process(target=launch_server, args=(server_args,))
|
||||
p.start()
|
||||
|
||||
if server_args.node_rank != 0:
|
||||
return p
|
||||
|
||||
base_url = server_args.url()
|
||||
headers = {
|
||||
"Content-Type": "application/json; charset=utf-8",
|
||||
"Authorization": f"Bearer {server_args.api_key}" if server_args.api_key else "",
|
||||
}
|
||||
|
||||
with requests.Session() as session:
|
||||
# Wait for server to be ready
|
||||
while True:
|
||||
try:
|
||||
response = session.get(f"{base_url}/health_generate", headers=headers)
|
||||
if response.status_code == 200:
|
||||
break
|
||||
except requests.RequestException:
|
||||
pass
|
||||
|
||||
if not p.is_alive():
|
||||
raise Exception("SGLang server process terminated unexpectedly.")
|
||||
|
||||
time.sleep(2)
|
||||
|
||||
# Ensure working queue is empty for offload support
|
||||
while True:
|
||||
try:
|
||||
response = session.get(f"{base_url}/flush_cache", headers=headers)
|
||||
if response.status_code == 200:
|
||||
break
|
||||
except requests.RequestException:
|
||||
pass
|
||||
|
||||
if not p.is_alive():
|
||||
raise Exception("SGLang server process terminated unexpectedly.")
|
||||
|
||||
time.sleep(2)
|
||||
|
||||
return p
|
||||
|
||||
|
||||
class SGLangHttpServerEngineAdapter:
|
||||
"""
|
||||
SGLang HTTP Server Engine Adapter based on slime's HttpServerEngineAdapter.
|
||||
|
||||
This class provides a clean interface to launch and manage SGLang HTTP servers
|
||||
with proper weight synchronization, memory management, and distributed support.
|
||||
"""
|
||||
|
||||
def __init__(self, router_ip=None, router_port=None, **kwargs):
|
||||
self.router_ip = router_ip
|
||||
self.router_port = router_port
|
||||
self.server_args = ServerArgs(**kwargs)
|
||||
self.node_rank = self.server_args.node_rank
|
||||
|
||||
logger.info(f"Launch SGLangHttpServerEngineAdapter at: {self.server_args.host}:{self.server_args.port}")
|
||||
|
||||
# Launch server process
|
||||
self.process = launch_server_process(self.server_args)
|
||||
|
||||
# Register with router if specified
|
||||
if self.node_rank == 0 and self.router_ip and self.router_port:
|
||||
try:
|
||||
requests.post(
|
||||
f"http://{self.router_ip}:{self.router_port}/add_worker"
|
||||
f"?url=http://{self.server_args.host}:{self.server_args.port}"
|
||||
)
|
||||
except requests.RequestException as e:
|
||||
logger.warning(f"Failed to register with router: {e}")
|
||||
|
||||
def _make_request(self, endpoint: str, payload: Optional[dict] = None, method: str = "POST"):
|
||||
"""
|
||||
Make a request to the SGLang server.
|
||||
|
||||
Args:
|
||||
endpoint: The API endpoint to call
|
||||
payload: The JSON payload to send (default: empty dict)
|
||||
method: HTTP method (GET or POST)
|
||||
|
||||
Returns:
|
||||
The JSON response from the server
|
||||
"""
|
||||
if self.node_rank != 0:
|
||||
return
|
||||
|
||||
url = f"http://{self.server_args.host}:{self.server_args.port}/{endpoint}"
|
||||
|
||||
if method.upper() == "GET":
|
||||
response = requests.get(url)
|
||||
else:
|
||||
response = requests.post(url, json=payload or {})
|
||||
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
def update_weights_from_tensor(
|
||||
self,
|
||||
serialized_named_tensors: list[str],
|
||||
load_format: Optional[str] = None,
|
||||
flush_cache: bool = False,
|
||||
):
|
||||
"""
|
||||
Update model weights from tensor data using SGLang's native API.
|
||||
|
||||
This method uses SGLang's built-in weight update mechanism for efficient
|
||||
GPU-to-GPU weight transfer without CPU intermediary.
|
||||
"""
|
||||
return self._make_request(
|
||||
"update_weights_from_tensor",
|
||||
{
|
||||
"serialized_named_tensors": serialized_named_tensors,
|
||||
"load_format": load_format,
|
||||
"flush_cache": flush_cache,
|
||||
},
|
||||
)
|
||||
|
||||
def update_weights_from_distributed(self, names, dtypes, shapes, group_name, flush_cache=False):
|
||||
"""
|
||||
Update model weights from distributed training using NCCL broadcast.
|
||||
|
||||
This is the preferred method for weight synchronization in distributed setups.
|
||||
"""
|
||||
return self._make_request(
|
||||
"update_weights_from_distributed",
|
||||
{
|
||||
"names": names,
|
||||
"dtypes": [str(dtype).replace("torch.", "") for dtype in dtypes],
|
||||
"shapes": shapes,
|
||||
"group_name": group_name,
|
||||
"flush_cache": flush_cache,
|
||||
},
|
||||
)
|
||||
|
||||
def update_weights_bucketed(self, names, dtypes, shapes, group_name, flush_cache=False):
|
||||
"""
|
||||
Update model weights using bucketed batch approach (slime-style).
|
||||
|
||||
This method is optimized for large models with memory-aware parameter bucketing.
|
||||
"""
|
||||
return self._make_request(
|
||||
"update_weights_from_distributed",
|
||||
{
|
||||
"names": names,
|
||||
"dtypes": [str(dtype).replace("torch.", "") for dtype in dtypes],
|
||||
"shapes": shapes,
|
||||
"group_name": group_name,
|
||||
"flush_cache": flush_cache,
|
||||
},
|
||||
)
|
||||
|
||||
def pause_generation(self):
|
||||
"""Pause generation on the server."""
|
||||
return self._make_request("pause_generation")
|
||||
|
||||
def continue_generation(self):
|
||||
"""Continue generation on the server."""
|
||||
return self._make_request("continue_generation")
|
||||
|
||||
def get_memory_info(self):
|
||||
"""Get memory information from the server."""
|
||||
return self._make_request("get_memory_info", method="GET")
|
||||
|
||||
def init_weights_update_group(self, master_address, master_port, rank_offset, world_size, group_name, backend):
|
||||
"""
|
||||
Initialize the distributed weight update group.
|
||||
"""
|
||||
return self._make_request(
|
||||
"init_weights_update_group",
|
||||
{
|
||||
"master_address": master_address,
|
||||
"master_port": master_port,
|
||||
"rank_offset": rank_offset,
|
||||
"world_size": world_size,
|
||||
"group_name": group_name,
|
||||
"backend": backend,
|
||||
},
|
||||
)
|
||||
|
||||
def flush_cache(self):
|
||||
"""Flush the cache of the server."""
|
||||
if self.node_rank != 0:
|
||||
return
|
||||
|
||||
# flush_cache will not return status_code 200 when there are pending requests
|
||||
while True:
|
||||
try:
|
||||
response = requests.get(f"http://{self.server_args.host}:{self.server_args.port}/flush_cache")
|
||||
if response.status_code == 200:
|
||||
break
|
||||
except NewConnectionError as e:
|
||||
raise e
|
||||
except Exception as e:
|
||||
logger.error(f"Error flushing cache: {e}")
|
||||
continue
|
||||
|
||||
def release_memory_occupation(self):
|
||||
"""Release memory occupation for offloading support."""
|
||||
return self._make_request("release_memory_occupation")
|
||||
|
||||
def resume_memory_occupation(self):
|
||||
"""Resume memory occupation after offloading."""
|
||||
return self._make_request("resume_memory_occupation")
|
||||
|
||||
def generate(self, prompts, sampling_params, images=None):
|
||||
"""
|
||||
Generate completions using the SGLang server.
|
||||
|
||||
Args:
|
||||
prompts: List of text prompts
|
||||
sampling_params: Dictionary of sampling parameters
|
||||
images: Optional list of images for multi-modal generation
|
||||
|
||||
Returns:
|
||||
Generated completions
|
||||
"""
|
||||
payload = {
|
||||
"text": prompts,
|
||||
"sampling_params": sampling_params,
|
||||
}
|
||||
|
||||
if images:
|
||||
payload["images"] = images
|
||||
|
||||
return self._make_request("generate", payload)
|
||||
|
||||
def shutdown(self):
|
||||
"""Shutdown the server and clean up resources."""
|
||||
# Deregister from router
|
||||
if self.router_ip and self.router_port:
|
||||
try:
|
||||
requests.post(
|
||||
f"http://{self.router_ip}:{self.router_port}/remove_worker"
|
||||
f"?url=http://{self.server_args.host}:{self.server_args.port}"
|
||||
)
|
||||
except requests.RequestException:
|
||||
pass # Router might be down
|
||||
|
||||
# Kill the server process
|
||||
kill_process_tree(self.process.pid)
|
||||
|
||||
|
||||
class SGLangEngine:
|
||||
"""
|
||||
SGLang Engine wrapper based on slime's SglangEngine.
|
||||
|
||||
This class provides a higher-level interface for managing SGLang engines
|
||||
with proper resource management and distributed support.
|
||||
"""
|
||||
|
||||
def __init__(self, args, rank, dist_init_addr, port, nccl_port):
|
||||
self.args = args
|
||||
self.rank = rank
|
||||
|
||||
# Remove CUDA_VISIBLE_DEVICES set by ray/accelerate and use base_gpu_id
|
||||
os.environ.pop("CUDA_VISIBLE_DEVICES", None)
|
||||
|
||||
# Calculate distributed configuration
|
||||
nnodes = max(
|
||||
1, getattr(args, "sglang_tensor_parallel_size", 1) // getattr(args, "sglang_num_gpus_per_node", 8)
|
||||
)
|
||||
node_rank = rank % nnodes
|
||||
|
||||
# Prepare server configuration
|
||||
server_kwargs = {
|
||||
"model_path": args.model if hasattr(args, "model") else args.sglang_model_path,
|
||||
"trust_remote_code": getattr(args, "trust_remote_code", True),
|
||||
"random_seed": getattr(args, "seed", 42) + rank,
|
||||
# Memory configuration
|
||||
"enable_memory_saver": getattr(args, "offload", False),
|
||||
"mem_fraction_static": getattr(args, "sglang_gpu_memory_utilization", 0.2),
|
||||
# Distributed configuration
|
||||
"host": getattr(args, "sglang_host", "0.0.0.0"),
|
||||
"port": port,
|
||||
"nccl_port": nccl_port,
|
||||
"nnodes": nnodes,
|
||||
"node_rank": node_rank,
|
||||
"dist_init_addr": dist_init_addr,
|
||||
"gpu_id_step": 1,
|
||||
"base_gpu_id": get_base_gpu_id(args, rank),
|
||||
# Parallelism configuration
|
||||
"tp_size": getattr(args, "sglang_tensor_parallel_size", 1),
|
||||
"dp_size": getattr(args, "sglang_data_parallel_size", 1),
|
||||
"pp_size": getattr(args, "sglang_pipeline_parallel_size", 1),
|
||||
"ep_size": getattr(args, "sglang_expert_parallel_size", 1),
|
||||
# Performance configuration
|
||||
"skip_server_warmup": True, # Always skip warmup to prevent timeout
|
||||
}
|
||||
|
||||
# Filter out None values and unsupported arguments
|
||||
server_kwargs = {k: v for k, v in server_kwargs.items() if v is not None}
|
||||
|
||||
# Create the HTTP server engine adapter
|
||||
self.llm = SGLangHttpServerEngineAdapter(
|
||||
router_ip=getattr(args, "sglang_router_ip", None),
|
||||
router_port=getattr(args, "sglang_router_port", None),
|
||||
**server_kwargs,
|
||||
)
|
||||
|
||||
def init_process_group(self, master_address, master_port, rank_offset, world_size, group_name, backend):
|
||||
"""Initialize the distributed process group for weight updates."""
|
||||
return self.llm.init_weights_update_group(
|
||||
master_address, master_port, rank_offset, world_size, group_name, backend
|
||||
)
|
||||
|
||||
def update_weights_from_distributed(self, names, dtypes, shapes, group_name):
|
||||
"""Update weights from distributed training."""
|
||||
self.llm.update_weights_from_distributed(names, dtypes, shapes, group_name)
|
||||
return
|
||||
|
||||
def update_weights_from_tensor(self, ipc_handles):
|
||||
"""Update weights from tensor handles."""
|
||||
self.llm.update_weights_from_tensor(ipc_handles)
|
||||
return
|
||||
|
||||
def reset_prefix_cache(self):
|
||||
"""Reset the prefix cache."""
|
||||
self.llm.flush_cache()
|
||||
|
||||
def sleep(self, level=1):
|
||||
"""Release memory occupation for offloading."""
|
||||
self.llm.flush_cache()
|
||||
self.llm.release_memory_occupation()
|
||||
|
||||
def wake_up(self):
|
||||
"""Resume memory occupation after offloading."""
|
||||
self.llm.resume_memory_occupation()
|
||||
|
||||
def pause_generation(self):
|
||||
"""Pause generation."""
|
||||
self.llm.pause_generation()
|
||||
|
||||
def continue_generation(self):
|
||||
"""Continue generation."""
|
||||
self.llm.continue_generation()
|
||||
|
||||
def generate(self, prompts, sampling_params, images=None):
|
||||
"""Generate completions."""
|
||||
return self.llm.generate(prompts, sampling_params, images)
|
||||
|
||||
def shutdown(self):
|
||||
"""Shutdown the engine."""
|
||||
self.llm.shutdown()
|

trl/extras/sglang_weight_utils.py (new file)
@@ -0,0 +1,434 @@
|
||||
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
SGLang Weight Update Utilities - Slime-style batched weight updates with bucketing and memory management.
|
||||
"""
|
||||
|
||||
import gc
|
||||
import logging
|
||||
import time
|
||||
from contextlib import nullcontext
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
from accelerate.utils import is_peft_model
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ParamInfo:
|
||||
"""Parameter information for distributed weight updates."""
|
||||
|
||||
name: str
|
||||
dtype: torch.dtype
|
||||
shape: torch.Size
|
||||
attrs: dict # Contains tensor parallelism and other attributes
|
||||
size: int # Parameter size in bytes
|
||||
src_rank: int # Source rank that owns this parameter
|
||||
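For concreteness, here is how the fields of a single entry would be filled in (a hand-checked illustration; the tensor is made up and mirrors what `get_param_infos` below does):

```python
import torch

p = torch.nn.Parameter(torch.zeros(4096, 4096, dtype=torch.bfloat16))
info = ParamInfo(
    name="dummy.weight",
    dtype=p.dtype,
    shape=p.shape,
    attrs={},
    size=p.numel() * p.element_size(),  # 4096 * 4096 * 2 bytes = 32 MiB
    src_rank=0,
)
```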
|
||||
|
||||
class SGLangWeightUpdater:
|
||||
"""
|
||||
Slime-style batched weight updater for SGLang engines with memory management and bucketing.
|
||||
|
||||
This class implements sophisticated parameter bucketing and distributed weight updates
|
||||
following the patterns from the slime framework for optimal performance.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: torch.nn.Module,
|
||||
sglang_mode: str,
|
||||
sglang_client=None,
|
||||
sglang_engine=None,
|
||||
accelerator=None,
|
||||
update_weight_buffer_size: int = 512 * 1024**2, # 512MB default
|
||||
):
|
||||
self.model = model
|
||||
self.sglang_mode = sglang_mode
|
||||
self.sglang_client = sglang_client
|
||||
self.sglang_engine = sglang_engine
|
||||
self.accelerator = accelerator
|
||||
self.update_weight_buffer_size = update_weight_buffer_size
|
||||
|
||||
# Initialize distributed groups if needed
|
||||
self._init_distributed_groups()
|
||||
|
||||
def _init_distributed_groups(self):
|
||||
"""Initialize distributed process groups for weight updates."""
|
||||
if self.accelerator and self.accelerator.is_main_process:
|
||||
self._is_main_process = True
|
||||
self._group_name = "sglang_weight_sync"
|
||||
else:
|
||||
self._is_main_process = False
|
||||
|
||||
def get_param_infos(self, model: torch.nn.Module) -> list[ParamInfo]:
|
||||
"""
|
||||
Extract parameter information from the model.
|
||||
|
||||
Args:
|
||||
model: The model to extract parameters from
|
||||
|
||||
Returns:
|
||||
List of ParamInfo objects containing parameter metadata
|
||||
"""
|
||||
param_infos = []
|
||||
|
||||
for name, param in model.named_parameters():
|
||||
if not param.requires_grad:
|
||||
continue
|
||||
|
||||
# Calculate parameter size in bytes
|
||||
param_size = param.numel() * param.element_size()
|
||||
|
||||
# Extract tensor parallelism attributes if available
|
||||
attrs = {}
|
||||
if hasattr(param, "tensor_model_parallel"):
|
||||
attrs["tensor_model_parallel"] = param.tensor_model_parallel
|
||||
if hasattr(param, "partition_dim"):
|
||||
attrs["partition_dim"] = param.partition_dim
|
||||
if hasattr(param, "partition_stride"):
|
||||
attrs["partition_stride"] = param.partition_stride
|
||||
|
||||
# Determine source rank (simplified - in full implementation would handle TP/PP)
|
||||
src_rank = 0 if self.accelerator is None else self.accelerator.process_index
|
||||
|
||||
param_info = ParamInfo(
|
||||
name=name, dtype=param.dtype, shape=param.shape, attrs=attrs, size=param_size, src_rank=src_rank
|
||||
)
|
||||
param_infos.append(param_info)
|
||||
|
||||
return param_infos
|
||||
|
||||
def get_param_info_buckets(self, param_infos: list[ParamInfo]) -> list[list[ParamInfo]]:
|
||||
"""
|
||||
Group parameters into buckets based on memory constraints.
|
||||
|
||||
Args:
|
||||
param_infos: List of parameter information
|
||||
|
||||
Returns:
|
||||
List of parameter buckets, each respecting the buffer size limit
|
||||
"""
|
||||
param_info_buckets = [[]]
|
||||
buffer_size = 0
|
||||
oversized_params = []
|
||||
|
||||
for info in param_infos:
|
||||
# Calculate effective parameter size (accounting for tensor parallelism)
|
||||
tp_size = 1
|
||||
if hasattr(self.accelerator, "num_processes"):
|
||||
tp_size = self.accelerator.num_processes
|
||||
|
||||
# Handle expert parameters if present (MoE models). Expert and dense parameters are
# currently scaled identically; a different TP size for experts could be plugged in here.
effective_param_size = info.size * tp_size
|
||||
|
||||
# If a single parameter exceeds the buffer size, handle it separately
|
||||
if effective_param_size > self.update_weight_buffer_size:
|
||||
oversized_params.append(info)
|
||||
logger.warning(
|
||||
f"Parameter {info.name} ({effective_param_size / (1024**2):.2f}MB) exceeds buffer size ({self.update_weight_buffer_size / (1024**2):.2f}MB), will be processed individually"
|
||||
)
|
||||
continue
|
||||
|
||||
# Check if adding this parameter would exceed the buffer size
|
||||
if buffer_size + effective_param_size > self.update_weight_buffer_size and param_info_buckets[-1]:
|
||||
# Start a new bucket
|
||||
param_info_buckets.append([])
|
||||
buffer_size = 0
|
||||
|
||||
param_info_buckets[-1].append(info)
|
||||
buffer_size += effective_param_size
|
||||
|
||||
# Add oversized parameters as individual buckets
|
||||
for oversized_param in oversized_params:
|
||||
param_info_buckets.append([oversized_param])
|
||||
|
||||
# Remove empty buckets
|
||||
param_info_buckets = [bucket for bucket in param_info_buckets if bucket]
|
||||
|
||||
logger.info(
|
||||
f"Created {len(param_info_buckets)} parameter buckets with buffer size {self.update_weight_buffer_size / (1024**2):.2f}MB"
|
||||
)
|
||||
if oversized_params:
|
||||
logger.info(f"Found {len(oversized_params)} oversized parameters that will be processed individually")
|
||||
|
||||
return param_info_buckets
|
||||
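To make the bucketing behaviour concrete, a small hand-checkable exercise of the method above (parameter sizes are made up; the dummy `Linear` model only satisfies the constructor):

```python
import torch

MB = 1024**2
updater = SGLangWeightUpdater(model=torch.nn.Linear(2, 2), sglang_mode="colocate")
infos = [
    ParamInfo(name=n, dtype=torch.float32, shape=torch.Size([1]), attrs={}, size=s, src_rank=0)
    for n, s in [("a", 300 * MB), ("b", 300 * MB), ("c", 200 * MB), ("d", 700 * MB)]
]
buckets = updater.get_param_info_buckets(infos)
# With the default 512MB buffer this yields ["a"], ["b", "c"], and the oversized ["d"] on its own.
```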
|
||||
def _fix_param_name_for_sglang(self, name: str, extra_prefixes: Optional[list[str]] = None) -> str:
|
||||
"""Fix parameter names for SGLang compatibility."""
|
||||
extra_prefixes = extra_prefixes or []
|
||||
prefixes_to_remove = ["_checkpoint_wrapped_module."] + extra_prefixes
|
||||
|
||||
for prefix in prefixes_to_remove:
|
||||
name = name.replace(prefix, "")
|
||||
|
||||
return name
|
||||
|
||||
def _update_bucket_weights_server_mode(self, bucket_params: list[tuple[str, torch.Tensor]]) -> None:
|
||||
"""
|
||||
Update weights for a bucket of parameters in server mode.
|
||||
|
||||
Args:
|
||||
bucket_params: List of (name, parameter) tuples for this bucket
|
||||
"""
|
||||
if not self.accelerator.is_main_process:
|
||||
return
|
||||
|
||||
names = [name for name, _ in bucket_params]
|
||||
dtypes = [str(param.dtype) for _, param in bucket_params]
|
||||
shapes = [list(param.shape) for _, param in bucket_params]
|
||||
|
||||
# Use SGLang client's batch update API
|
||||
url = f"{self.sglang_client.base_url}/update_weights/"
|
||||
response = self.sglang_client.session.post(
|
||||
url,
|
||||
json={
|
||||
"names": names,
|
||||
"dtypes": dtypes,
|
||||
"shapes": shapes,
|
||||
"group_name": self._group_name,
|
||||
"flush_cache": False, # Don't flush cache for each bucket
|
||||
},
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
raise Exception(f"SGLang bucket weight update failed: {response.status_code}, {response.text}")
|
||||
|
||||
logger.debug(f"Updated bucket with {len(names)} parameters in server mode")
|
||||
|
||||
def _update_bucket_weights_colocate_mode(self, bucket_params: list[tuple[str, torch.Tensor]]) -> None:
|
||||
"""
|
||||
Update weights for a bucket of parameters in colocate mode.
|
||||
|
||||
Args:
|
||||
bucket_params: List of (name, parameter) tuples for this bucket
|
||||
"""
|
||||
names = [name for name, _ in bucket_params]
|
||||
dtypes = [str(param.dtype) for _, param in bucket_params]
|
||||
shapes = [list(param.shape) for _, param in bucket_params]
|
||||
|
||||
# Single NCCL operation for all parameters in the bucket
|
||||
try:
|
||||
self.sglang_engine.update_weights_from_distributed(names, dtypes, shapes, self._group_name)
|
||||
except Exception as e:
|
||||
logger.warning(f"SGLang weight update failed: {e}")
|
||||
logger.warning("Falling back to individual parameter updates")
|
||||
# Fallback to individual updates if batch update fails
|
||||
for name, dtype, shape in zip(names, dtypes, shapes):
|
||||
try:
|
||||
self.sglang_engine.update_weights_from_distributed([name], [dtype], [shape], self._group_name)
|
||||
except Exception as e2:
|
||||
logger.error(f"Failed to update parameter {name}: {e2}")
|
||||
|
||||
logger.debug(f"Updated bucket with {len(names)} parameters in colocate mode")
|
||||
|
||||
def _process_peft_parameters_bucketed(self, model: torch.nn.Module, gather_if_zero3) -> None:
|
||||
"""Process PEFT parameters using bucketed updates."""
|
||||
# Collect all PEFT parameters first
|
||||
peft_params = []
|
||||
param_infos = []
|
||||
|
||||
for name, param in model.named_parameters():
|
||||
# Apply PEFT parameter name transformations
|
||||
name = name.removeprefix("base_model.model.").replace(".base_layer", "")
|
||||
|
||||
# Skip certain PEFT parameters
|
||||
if "lora_dropout" in name or "modules_to_save.default.lm_head" in name:
|
||||
continue
|
||||
if "original_module" in name:
|
||||
continue
|
||||
if hasattr(model, "prefix") and model.prefix in name:
|
||||
continue
|
||||
|
||||
name = self._fix_param_name_for_sglang(name, extra_prefixes=["modules_to_save.default."])
|
||||
peft_params.append((name, param))
|
||||
|
||||
# Create param info for bucketing
|
||||
param_info = ParamInfo(
|
||||
name=name,
|
||||
dtype=param.dtype,
|
||||
shape=param.shape,
|
||||
attrs={},
|
||||
size=param.numel() * param.element_size(),
|
||||
src_rank=0,
|
||||
)
|
||||
param_infos.append(param_info)
|
||||
|
||||
if not peft_params:
|
||||
return
|
||||
|
||||
# Create buckets for PEFT parameters
|
||||
param_buckets = self.get_param_info_buckets(param_infos)
|
||||
|
||||
# Process each bucket
|
||||
with tqdm(total=len(param_buckets), desc="Updating PEFT parameter buckets") as pbar:
|
||||
for bucket_infos in param_buckets:
|
||||
bucket_params = []
|
||||
for param_info in bucket_infos:
|
||||
# Find the corresponding parameter
|
||||
for name, param in peft_params:
|
||||
if name == param_info.name:
|
||||
bucket_params.append((name, param))
|
||||
break
|
||||
|
||||
if bucket_params:
|
||||
if self.sglang_mode == "server":
|
||||
self._update_bucket_weights_server_mode(bucket_params)
|
||||
elif self.sglang_mode == "colocate":
|
||||
self._update_bucket_weights_colocate_mode(bucket_params)
|
||||
|
||||
pbar.update(1)
|
||||
|
||||
def _process_regular_parameters_bucketed(self, model: torch.nn.Module, gather_if_zero3) -> None:
|
||||
"""Process regular parameters using bucketed updates."""
|
||||
# Collect all regular parameters first
|
||||
regular_params = []
|
||||
param_infos = []
|
||||
|
||||
for name, param in model.named_parameters():
|
||||
name = self._fix_param_name_for_sglang(name)
|
||||
regular_params.append((name, param))
|
||||
|
||||
# Create param info for bucketing
|
||||
param_info = ParamInfo(
|
||||
name=name,
|
||||
dtype=param.dtype,
|
||||
shape=param.shape,
|
||||
attrs={},
|
||||
size=param.numel() * param.element_size(),
|
||||
src_rank=0,
|
||||
)
|
||||
param_infos.append(param_info)
|
||||
|
||||
if not regular_params:
|
||||
return
|
||||
|
||||
# Create buckets for regular parameters
|
||||
param_buckets = self.get_param_info_buckets(param_infos)
|
||||
|
||||
# Process each bucket with parameter gathering
|
||||
with tqdm(total=len(param_buckets), desc="Updating regular parameter buckets") as pbar:
|
||||
for bucket_infos in param_buckets:
|
||||
bucket_params = []
|
||||
params_to_gather = []
|
||||
|
||||
for param_info in bucket_infos:
|
||||
# Find the corresponding parameter
|
||||
for name, param in regular_params:
|
||||
if name == param_info.name:
|
||||
bucket_params.append((name, param))
|
||||
params_to_gather.append(param)
|
||||
break
|
||||
|
||||
if bucket_params:
|
||||
# Gather all parameters in this bucket at once
|
||||
with gather_if_zero3(params_to_gather):
|
||||
if self.sglang_mode == "server":
|
||||
self._update_bucket_weights_server_mode(bucket_params)
|
||||
elif self.sglang_mode == "colocate":
|
||||
self._update_bucket_weights_colocate_mode(bucket_params)
|
||||
|
||||
pbar.update(1)
|
||||
|
||||
def _flush_sglang_cache(self):
|
||||
"""Flush SGLang cache after all weight updates."""
|
||||
if self.sglang_mode == "server" and self.accelerator.is_main_process:
|
||||
self.sglang_client.flush_cache()
|
||||
elif self.sglang_mode == "colocate":
|
||||
self.sglang_engine.reset_prefix_cache()
|
||||
|
||||
def update_model_weights(self, deepspeed_plugin=None) -> None:
|
||||
"""
|
||||
Update SGLang model weights using slime-style bucketed approach.
|
||||
|
||||
This is the main entry point that handles the complete weight update process
|
||||
with memory-aware bucketing and proper distributed coordination.
|
||||
"""
|
||||
logger.info("Starting slime-style bucketed weight update")
|
||||
start_time = time.time()
|
||||
|
||||
# Setup gathering context for DeepSpeed ZeRO-3
|
||||
zero_stage_3 = deepspeed_plugin is not None and deepspeed_plugin.zero_stage == 3
|
||||
if zero_stage_3:
|
||||
import deepspeed
|
||||
|
||||
gather_if_zero3 = deepspeed.zero.GatheredParameters
|
||||
else:
|
||||
gather_if_zero3 = nullcontext
|
||||
|
||||
# Clear GPU memory before starting
|
||||
self._clear_gpu_memory()
|
||||
|
||||
if is_peft_model(self.model):
|
||||
# Handle PEFT models with adapter merging
|
||||
with gather_if_zero3(list(self.model.parameters())):
|
||||
self.model.merge_adapter()
|
||||
|
||||
# Process PEFT parameters with bucketing
|
||||
if hasattr(self.accelerator.state, "fsdp_plugin") and self.accelerator.state.fsdp_plugin is not None:
|
||||
# FSDP handling would go here - simplified for now
|
||||
self._process_peft_parameters_bucketed(self.model, gather_if_zero3)
|
||||
else:
|
||||
# DeepSpeed ZeRO-3 with PEFT
|
||||
self._process_peft_parameters_bucketed(self.model, gather_if_zero3)
|
||||
|
||||
# Unmerge adapters after update
|
||||
self.model.unmerge_adapter()
|
||||
else:
|
||||
# Handle regular models without PEFT
|
||||
if hasattr(self.accelerator.state, "fsdp_plugin") and self.accelerator.state.fsdp_plugin is not None:
|
||||
# FSDP handling would go here - simplified for now
|
||||
self._process_regular_parameters_bucketed(self.model, gather_if_zero3)
|
||||
else:
|
||||
# Regular parameter processing with bucketing
|
||||
self._process_regular_parameters_bucketed(self.model, gather_if_zero3)
|
||||
|
||||
# Flush cache once at the end
|
||||
self._flush_sglang_cache()
|
||||
|
||||
# Clear memory after updates
|
||||
self._clear_gpu_memory()
|
||||
|
||||
elapsed_time = time.time() - start_time
|
||||
logger.info(f"Completed slime-style weight update in {elapsed_time:.2f} seconds")
|
||||
|
||||
def _clear_gpu_memory(self):
|
||||
"""Clear GPU memory to prevent OOM issues."""
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.synchronize()
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def get_memory_info(self) -> dict[str, Any]:
|
||||
"""Get current GPU memory information."""
|
||||
if not torch.cuda.is_available():
|
||||
return {"gpu": "none", "total_GB": 0, "free_GB": 0, "used_GB": 0}
|
||||
|
||||
free, total = torch.cuda.mem_get_info(torch.cuda.current_device())
|
||||
return {
|
||||
"gpu": str(torch.cuda.current_device()),
|
||||
"total_GB": round(total / (1024**3), 2),
|
||||
"free_GB": round(free / (1024**3), 2),
|
||||
"used_GB": round((total - free) / (1024**3), 2),
|
||||
}
|
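End-to-end, the trainer drives this class roughly as follows (a sketch assuming colocate mode; `model`, `accelerator`, and `engine` come from the trainer setup shown further below):

```python
updater = SGLangWeightUpdater(
    model=model,
    sglang_mode="colocate",
    sglang_engine=engine,
    accelerator=accelerator,
    update_weight_buffer_size=512 * 1024**2,
)
updater.update_model_weights(deepspeed_plugin=accelerator.state.deepspeed_plugin)
print(updater.get_memory_info())
```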
@ -38,6 +38,7 @@ _uvicorn_available = _is_package_available("uvicorn")
|
||||
_vllm_available = _is_package_available("vllm")
|
||||
_vllm_ascend_available = _is_package_available("vllm_ascend")
|
||||
_joblib_available = _is_package_available("joblib")
|
||||
_sglang_available = _is_package_available("sglang")
|
||||
|
||||
|
||||
def is_deepspeed_available() -> bool:
|
||||
@ -92,6 +93,53 @@ def is_joblib_available() -> bool:
|
||||
return _joblib_available
|
||||
|
||||
|
||||
def is_sglang_available() -> bool:
|
||||
"""
|
||||
Check if SGLang is available and can be imported successfully.
|
||||
|
||||
This function performs comprehensive checks to ensure SGLang is not only installed
|
||||
but also functionally importable, which can fail due to missing dependencies,
|
||||
GPU/CUDA requirements, or version incompatibilities.
|
||||
|
||||
Returns:
|
||||
bool: True if SGLang is available and all required modules can be imported
|
||||
"""
|
||||
if not _sglang_available:
|
||||
return False
|
||||
|
||||
# Check if core SGLang modules can be imported
|
||||
# These are the essential modules used by TRL's SGLang integration
|
||||
required_modules = [
|
||||
("sglang.srt.entrypoints.http_server", "launch_server"),
|
||||
("sglang.srt.server_args", "ServerArgs"),
|
||||
("sglang.srt.utils", "kill_process_tree"),
|
||||
]
|
||||
|
||||
for module_name, attr_name in required_modules:
|
||||
try:
|
||||
module = __import__(module_name, fromlist=[attr_name])
|
||||
if not hasattr(module, attr_name):
|
||||
return False
|
||||
except (ImportError, ModuleNotFoundError, AttributeError):
|
||||
return False
|
||||
except Exception:
|
||||
# Catch other potential issues like CUDA/GPU initialization errors,
|
||||
# missing Triton kernels, version mismatches, etc.
|
||||
return False
|
||||
|
||||
# Additional runtime check to ensure SGLang server can actually be instantiated
|
||||
# This catches issues with GPU availability, CUDA compatibility, etc.
|
||||
try:
|
||||
from sglang.srt.server_args import ServerArgs
|
||||
|
||||
# Try to create a minimal ServerArgs instance to verify basic functionality
|
||||
_ = ServerArgs(model_path="dummy")
|
||||
return True
|
||||
except Exception:
|
||||
# If ServerArgs creation fails, SGLang likely has environment issues
|
||||
return False
|
||||
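The check is meant to be used as an import guard, mirroring how the rest of the integration is gated (minimal sketch):

```python
from trl.import_utils import is_sglang_available

if is_sglang_available():
    # Heavy SGLang imports only happen once the environment has been verified.
    from sglang.srt.server_args import ServerArgs
```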
|
||||
|
||||
class _LazyModule(ModuleType):
|
||||
"""
|
||||
Module class that surfaces all objects but only performs associated imports when the objects are requested.
|
||||
|
530
trl/scripts/sglang_serve.py
Normal file
@ -0,0 +1,530 @@
|
||||
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import os
|
||||
from contextlib import asynccontextmanager
|
||||
from dataclasses import dataclass, field
|
||||
from multiprocessing import Pipe, Process
|
||||
from multiprocessing.connection import Connection
|
||||
from typing import Optional
|
||||
|
||||
from transformers import is_vision_available
|
||||
|
||||
from trl import TrlParser
|
||||
from trl.import_utils import (
|
||||
is_fastapi_available,
|
||||
is_pydantic_available,
|
||||
is_requests_available,
|
||||
is_sglang_available,
|
||||
is_uvicorn_available,
|
||||
)
|
||||
|
||||
|
||||
if is_fastapi_available():
|
||||
from fastapi import FastAPI
|
||||
|
||||
|
||||
if is_pydantic_available():
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
if is_uvicorn_available():
|
||||
import uvicorn
|
||||
|
||||
|
||||
if is_requests_available():
|
||||
pass
|
||||
|
||||
|
||||
if is_vision_available():
|
||||
pass
|
||||
|
||||
|
||||
if is_sglang_available():
|
||||
from ..extras.sglang_engine_adapter import SGLangEngine
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# We use CUDA with multiprocessing, so we must use the 'spawn' start method
|
||||
os.environ["SGLANG_WORKER_MULTIPROC_METHOD"] = "spawn"
|
||||
|
||||
|
||||
@dataclass
|
||||
class ScriptArguments:
|
||||
r"""
|
||||
Arguments for the SGLang serve script.
|
||||
|
||||
Args:
|
||||
model (`str`):
|
||||
Model name or path to load the model from.
|
||||
revision (`str` or `None`, *optional*, defaults to `None`):
|
||||
Revision to use for the model.
|
||||
tensor_parallel_size (`int`, *optional*, defaults to `1`):
|
||||
Number of tensor parallel workers to use.
|
||||
data_parallel_size (`int`, *optional*, defaults to `1`):
|
||||
Number of data parallel workers to use.
|
||||
host (`str`, *optional*, defaults to `"0.0.0.0"`):
|
||||
Host address to run the server on.
|
||||
port (`int`, *optional*, defaults to `8001`):
|
||||
Port to run the server on.
|
||||
gpu_memory_utilization (`float`, *optional*, defaults to `0.9`):
|
||||
Ratio of GPU memory to reserve for the model.
|
||||
dtype (`str`, *optional*, defaults to `"auto"`):
|
||||
Data type to use for SGLang generation.
|
||||
max_model_len (`int` or `None`, *optional*, defaults to `None`):
|
||||
Maximum model length to use.
|
||||
enable_prefix_caching (`bool` or `None`, *optional*, defaults to `None`):
|
||||
Whether to enable prefix caching in SGLang.
|
||||
enforce_eager (`bool`, *optional*, defaults to `False`):
|
||||
Whether to enforce eager execution.
|
||||
trust_remote_code (`bool`, *optional*, defaults to `False`):
|
||||
Whether to trust remote code when loading models.
|
||||
log_level (`str`, *optional*, defaults to `"info"`):
|
||||
Log level for uvicorn.
|
||||
"""
|
||||
|
||||
model: str = field(
|
||||
metadata={"help": "Model name or path to load the model from."},
|
||||
)
|
||||
revision: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "Revision to use for the model."},
|
||||
)
|
||||
tensor_parallel_size: int = field(
|
||||
default=1,
|
||||
metadata={"help": "Number of tensor parallel workers to use."},
|
||||
)
|
||||
data_parallel_size: int = field(
|
||||
default=1,
|
||||
metadata={"help": "Number of data parallel workers to use."},
|
||||
)
|
||||
host: str = field(
|
||||
default="0.0.0.0",
|
||||
metadata={"help": "Host address to run the server on."},
|
||||
)
|
||||
port: int = field(
|
||||
default=8001,
|
||||
metadata={"help": "Port to run the server on."},
|
||||
)
|
||||
gpu_memory_utilization: float = field(
|
||||
default=0.9,
|
||||
metadata={"help": "Ratio of GPU memory to reserve for the model weights, activations, and KV cache."},
|
||||
)
|
||||
dtype: str = field(
|
||||
default="auto",
|
||||
metadata={"help": "Data type to use for SGLang generation."},
|
||||
)
|
||||
max_model_len: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={"help": "Maximum model length to use."},
|
||||
)
|
||||
enable_prefix_caching: Optional[bool] = field(
|
||||
default=None,
|
||||
metadata={"help": "Whether to enable prefix caching in SGLang."},
|
||||
)
|
||||
enforce_eager: Optional[bool] = field(
|
||||
default=False,
|
||||
metadata={"help": "Whether to enforce eager execution."},
|
||||
)
|
||||
trust_remote_code: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Whether to trust remote code when loading models."},
|
||||
)
|
||||
log_level: str = field(
|
||||
default="info",
|
||||
metadata={"help": "Log level for uvicorn."},
|
||||
)
|
||||
|
||||
|
||||
def sglang_worker(
|
||||
script_args: ScriptArguments, data_parallel_rank: int, master_port: int, connection: Connection
|
||||
) -> None:
|
||||
"""
|
||||
Worker process for SGLang engine using improved architecture.
|
||||
|
||||
Args:
|
||||
script_args: Configuration arguments
|
||||
data_parallel_rank: Rank of this worker in data parallel group
|
||||
master_port: Port for distributed communication
|
||||
connection: Pipe connection to parent process
|
||||
"""
|
||||
|
||||
# Set environment variables for data parallelism
|
||||
os.environ["SGLANG_DP_RANK"] = str(data_parallel_rank)
|
||||
os.environ["SGLANG_DP_SIZE"] = str(script_args.data_parallel_size)
|
||||
os.environ["SGLANG_DP_MASTER_PORT"] = str(master_port)
|
||||
|
||||
# Create SGLang engine using improved adapter
|
||||
port = script_args.port + data_parallel_rank
|
||||
nccl_port = master_port + 1000 + data_parallel_rank # Separate NCCL ports
|
||||
dist_init_addr = f"{script_args.host}:{master_port}"
|
||||
|
||||
# Add required attributes to script_args for SGLangEngine
|
||||
script_args.sglang_model_path = script_args.model
|
||||
script_args.sglang_host = script_args.host
|
||||
script_args.sglang_tensor_parallel_size = script_args.tensor_parallel_size
|
||||
script_args.sglang_data_parallel_size = script_args.data_parallel_size
|
||||
script_args.sglang_pipeline_parallel_size = 1
|
||||
script_args.sglang_expert_parallel_size = 1
|
||||
script_args.sglang_num_gpus_per_node = 8 # Default
|
||||
script_args.colocate = True
|
||||
script_args.offload = False
|
||||
|
||||
try:
|
||||
# Create SGLang engine
|
||||
sglang_engine = SGLangEngine(
|
||||
args=script_args,
|
||||
rank=data_parallel_rank,
|
||||
dist_init_addr=dist_init_addr,
|
||||
port=port,
|
||||
nccl_port=nccl_port,
|
||||
)
|
||||
|
||||
# Send ready signal
|
||||
connection.send({"status": "ready", "url": f"http://{script_args.host}:{port}"})
|
||||
|
||||
# Main loop to handle commands
|
||||
while True:
|
||||
try:
|
||||
command = connection.recv()
|
||||
except KeyboardInterrupt:
|
||||
break
|
||||
|
||||
if command["type"] == "init_communicator":
|
||||
sglang_engine.init_process_group(**command["kwargs"])
|
||||
connection.send({"status": "ok"})
|
||||
elif command["type"] == "init_weights_update_group":
|
||||
sglang_engine.init_process_group(**command["kwargs"])
|
||||
connection.send({"status": "ok"})
|
||||
elif command["type"] == "update_weights":
|
||||
sglang_engine.update_weights_from_distributed(**command["kwargs"])
|
||||
connection.send({"status": "ok"})
|
||||
elif command["type"] == "generate":
|
||||
result = sglang_engine.generate(**command["kwargs"])
|
||||
connection.send(result)
|
||||
elif command["type"] == "flush_cache":
|
||||
sglang_engine.reset_prefix_cache()
|
||||
connection.send({"status": "ok"})
|
||||
elif command["type"] == "pause_generation":
|
||||
sglang_engine.pause_generation()
|
||||
connection.send({"status": "ok"})
|
||||
elif command["type"] == "continue_generation":
|
||||
sglang_engine.continue_generation()
|
||||
connection.send({"status": "ok"})
|
||||
elif command["type"] == "shutdown":
|
||||
break
|
||||
|
||||
except Exception as e:
|
||||
connection.send({"status": "error", "message": str(e)})
|
||||
raise
|
||||
finally:
|
||||
# Cleanup
|
||||
if "sglang_engine" in locals():
|
||||
sglang_engine.shutdown()
|
||||
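The parent process talks to each worker over the pipe with small dict messages; for instance (hypothetical, mirroring the handlers above):

```python
parent_connection.send({"type": "flush_cache"})
assert parent_connection.recv() == {"status": "ok"}
```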
|
||||
|
||||
def chunk_list(lst: list, n: int) -> list[list]:
|
||||
"""Split list into n evenly distributed sublists."""
|
||||
k, r = divmod(len(lst), n)
|
||||
return [lst[i * k + min(i, r) : (i + 1) * k + min(i + 1, r)] for i in range(n)]
|
||||
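A couple of hand-checked cases of the chunking behaviour used to spread prompts across DP workers:

```python
chunk_list(["a", "b", "c", "d", "e"], 2)  # -> [["a", "b", "c"], ["d", "e"]]
chunk_list(["a", "b"], 3)                 # -> [["a"], ["b"], []]
```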
|
||||
|
||||
def main(script_args: ScriptArguments):
|
||||
if not is_fastapi_available():
|
||||
raise ImportError(
|
||||
"FastAPI is required to run the SGLang serve script. Please install it using `pip install fastapi`."
|
||||
)
|
||||
|
||||
if not is_pydantic_available():
|
||||
raise ImportError(
|
||||
"Pydantic is required to run the SGLang serve script. Please install it using `pip install pydantic`."
|
||||
)
|
||||
|
||||
if not is_uvicorn_available():
|
||||
raise ImportError(
|
||||
"Uvicorn is required to run the SGLang serve script. Please install it using `pip install uvicorn`."
|
||||
)
|
||||
|
||||
if not is_sglang_available():
|
||||
raise ImportError(
|
||||
"SGLang is required to run the SGLang serve script. Please install it using `pip install sglang`."
|
||||
)
|
||||
|
||||
# Spawn data parallel workers
|
||||
master_port = 29500 # Fixed port for DP communication
|
||||
connections = []
|
||||
processes = []
|
||||
worker_urls = []
|
||||
|
||||
for data_parallel_rank in range(script_args.data_parallel_size):
|
||||
parent_connection, child_connection = Pipe()
|
||||
process = Process(target=sglang_worker, args=(script_args, data_parallel_rank, master_port, child_connection))
|
||||
process.start()
|
||||
connections.append(parent_connection)
|
||||
processes.append(process)
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
# Wait for all workers to be ready
|
||||
for connection in connections:
|
||||
msg = connection.recv()
|
||||
if msg.get("status") == "ready":
|
||||
worker_urls.append(msg["url"])
|
||||
else:
|
||||
raise RuntimeError(f"Worker failed to start: {msg}")
|
||||
|
||||
yield
|
||||
|
||||
# Shutdown workers
|
||||
for connection in connections:
|
||||
connection.send({"type": "shutdown"})
|
||||
|
||||
# Wait for processes to terminate
|
||||
for process in processes:
|
||||
process.join(timeout=10)
|
||||
if process.is_alive():
|
||||
logger.warning(f"Process {process} is still alive, terminating...")
|
||||
process.terminate()
|
||||
process.join()
|
||||
|
||||
app = FastAPI(lifespan=lifespan)
|
||||
|
||||
# Define endpoints
|
||||
@app.get("/health/")
|
||||
async def health():
|
||||
"""Health check endpoint."""
|
||||
return {"status": "ok"}
|
||||
|
||||
@app.get("/get_world_size/")
|
||||
async def get_world_size():
|
||||
"""Get the total world size."""
|
||||
return {"world_size": script_args.tensor_parallel_size * script_args.data_parallel_size}
|
||||
|
||||
class GenerateRequest(BaseModel):
|
||||
prompts: list[str]
|
||||
images: Optional[list[str]] = None
|
||||
sampling_params: dict = {}  # pydantic copies mutable defaults per instance; dataclasses.field() is not valid here
|
||||
|
||||
class GenerateResponse(BaseModel):
|
||||
completion_ids: list[list[int]]
|
||||
|
||||
@app.post("/generate/", response_model=GenerateResponse)
|
||||
async def generate(request: GenerateRequest):
|
||||
"""Generate completions for the provided prompts."""
|
||||
# Distribute prompts across DP workers
|
||||
chunked_prompts = chunk_list(request.prompts, script_args.data_parallel_size)
|
||||
|
||||
# Send to workers
|
||||
for connection, prompts in zip(connections, chunked_prompts):
|
||||
if not prompts:
|
||||
prompts = ["<placeholder>"] # SGLang requires at least one prompt
|
||||
|
||||
kwargs = {
    # SGLangEngine.generate() expects `prompts`, `sampling_params` and optional `images`
    "prompts": prompts,
    "sampling_params": request.sampling_params,
}
|
||||
connection.send({"type": "generate", "kwargs": kwargs})
|
||||
|
||||
# Collect results. Every worker received a request (empty chunks were padded with a
# placeholder above), so receive from all of them to keep the pipes in sync, but only
# keep outputs from workers that were given real prompts.
all_outputs = []
for connection, prompts in zip(connections, chunked_prompts):
    output = connection.recv()
    if prompts:
        all_outputs.append(output)
|
||||
|
||||
# Combine results
|
||||
completion_ids = []
|
||||
for output in all_outputs:
|
||||
for item in output.get("text_outputs", []):
|
||||
# Extract token IDs from the output
|
||||
# This will need adjustment based on SGLang's actual output format
|
||||
completion_ids.append(item.get("token_ids", []))
|
||||
|
||||
return {"completion_ids": completion_ids}
|
||||
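A client can exercise this endpoint directly; a minimal sketch, assuming the server started by this script is listening on the default `localhost:8001`:

```python
import requests

response = requests.post(
    "http://localhost:8001/generate/",
    json={"prompts": ["The capital of France is"], "sampling_params": {"max_new_tokens": 16}},
)
print(response.json()["completion_ids"])
```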
|
||||
class InitCommunicatorRequest(BaseModel):
|
||||
host: str
|
||||
port: int
|
||||
world_size: int
|
||||
client_device_uuid: str
|
||||
|
||||
@app.post("/init_communicator/")
|
||||
async def init_communicator(request: InitCommunicatorRequest):
|
||||
"""Initialize the weight update communicator."""
|
||||
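# The extra +1 rank below presumably accounts for the trainer-side client joining the weight-sync group.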
world_size = script_args.tensor_parallel_size * script_args.data_parallel_size + 1
|
||||
|
||||
# Initialize communicator on all workers
|
||||
for i, connection in enumerate(connections):
|
||||
kwargs = {
|
||||
"master_address": request.host,
|
||||
"master_port": request.port,
|
||||
"rank_offset": i,
|
||||
"world_size": world_size,
|
||||
"group_name": "weight_sync",
|
||||
"backend": "nccl",
|
||||
}
|
||||
connection.send({"type": "init_communicator", "kwargs": kwargs})
|
||||
|
||||
# Wait for all to complete
|
||||
for connection in connections:
|
||||
connection.recv()
|
||||
|
||||
return {"message": "Communicator initialized"}
|
||||
|
||||
class InitWeightsUpdateGroupRequest(BaseModel):
|
||||
master_address: str
|
||||
master_port: int
|
||||
rank_offset: int
|
||||
world_size: int
|
||||
group_name: str
|
||||
backend: str
|
||||
|
||||
@app.post("/init_weights_update_group/")
|
||||
async def init_weights_update_group(request: InitWeightsUpdateGroupRequest):
|
||||
"""Initialize the weight update group for distributed training."""
|
||||
kwargs = {
|
||||
"master_address": request.master_address,
|
||||
"master_port": request.master_port,
|
||||
"rank_offset": request.rank_offset,
|
||||
"world_size": request.world_size,
|
||||
"group_name": request.group_name,
|
||||
"backend": request.backend,
|
||||
}
|
||||
|
||||
# Send to all workers
|
||||
for connection in connections:
|
||||
connection.send({"type": "init_weights_update_group", "kwargs": kwargs})
|
||||
|
||||
# Wait for all to complete
|
||||
for connection in connections:
|
||||
connection.recv()
|
||||
|
||||
return {"message": "Weight update group initialized"}
|
||||
|
||||
class UpdateWeightsRequest(BaseModel):
|
||||
names: list[str]
|
||||
dtypes: list[str]
|
||||
shapes: list[list[int]]
|
||||
|
||||
@app.post("/update_weights/")
|
||||
async def update_weights(request: UpdateWeightsRequest):
|
||||
"""Update model weights."""
|
||||
kwargs = {
|
||||
"names": request.names,
|
||||
"dtypes": request.dtypes,
|
||||
"shapes": request.shapes,
|
||||
"group_name": "weight_sync",
|
||||
"flush_cache": True,
|
||||
}
|
||||
|
||||
# Send to all workers
|
||||
for connection in connections:
|
||||
connection.send({"type": "update_weights", "kwargs": kwargs})
|
||||
|
||||
# Wait for all to complete
|
||||
for connection in connections:
|
||||
connection.recv()
|
||||
|
||||
return {"message": "Weights updated"}
|
||||
|
||||
class UpdateWeightsFromDistributedRequest(BaseModel):
|
||||
names: list[str]
|
||||
dtypes: list[str]
|
||||
shapes: list[list[int]]
|
||||
group_name: str = "weight_sync"
|
||||
flush_cache: bool = False
|
||||
|
||||
@app.post("/update_weights_from_distributed/")
|
||||
async def update_weights_from_distributed(request: UpdateWeightsFromDistributedRequest):
|
||||
"""Update model weights from distributed training using NCCL broadcast."""
|
||||
kwargs = {
|
||||
"names": request.names,
|
||||
"dtypes": request.dtypes,
|
||||
"shapes": request.shapes,
|
||||
"group_name": request.group_name,
|
||||
"flush_cache": request.flush_cache,
|
||||
}
|
||||
|
||||
# Send to all workers
|
||||
for connection in connections:
|
||||
connection.send({"type": "update_weights", "kwargs": kwargs})
|
||||
|
||||
# Wait for all to complete
|
||||
for connection in connections:
|
||||
connection.recv()
|
||||
|
||||
return {"message": "Distributed weights updated"}
|
||||
|
||||
@app.post("/flush_cache/")
|
||||
async def flush_cache():
|
||||
"""Flush the cache."""
|
||||
for connection in connections:
|
||||
connection.send({"type": "flush_cache"})
|
||||
|
||||
for connection in connections:
|
||||
connection.recv()
|
||||
|
||||
return {"message": "Cache flushed"}
|
||||
|
||||
@app.post("/close_communicator/")
|
||||
async def close_communicator():
|
||||
"""Close the weight update communicator."""
|
||||
# SGLang doesn't need explicit communicator closing
|
||||
return {"message": "Communicator closed"}
|
||||
|
||||
@app.post("/pause_generation/")
|
||||
async def pause_generation():
|
||||
"""Pause generation on all SGLang workers."""
|
||||
for connection in connections:
|
||||
connection.send({"type": "pause_generation"})
|
||||
|
||||
# Wait for all to complete
|
||||
for connection in connections:
|
||||
connection.recv()
|
||||
|
||||
return {"message": "Generation paused"}
|
||||
|
||||
@app.post("/continue_generation/")
|
||||
async def continue_generation():
|
||||
"""Continue generation on all SGLang workers."""
|
||||
for connection in connections:
|
||||
connection.send({"type": "continue_generation"})
|
||||
|
||||
# Wait for all to complete
|
||||
for connection in connections:
|
||||
connection.recv()
|
||||
|
||||
return {"message": "Generation continued"}
|
||||
|
||||
# Start the server
|
||||
uvicorn.run(app, host=script_args.host, port=script_args.port, log_level=script_args.log_level)
|
||||
|
||||
|
||||
def make_parser(subparsers: argparse._SubParsersAction = None):
|
||||
if subparsers is not None:
|
||||
parser = subparsers.add_parser(
|
||||
"sglang-serve", help="Run the SGLang serve script", dataclass_types=ScriptArguments
|
||||
)
|
||||
else:
|
||||
parser = TrlParser(ScriptArguments)
|
||||
return parser
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = make_parser()
|
||||
(script_args,) = parser.parse_args_and_config()
|
||||
main(script_args)
|
@ -445,6 +445,110 @@ class GRPOConfig(TrainingArguments):
|
||||
},
|
||||
)
|
||||
|
||||
# Parameters that control generation acceleration powered by SGLang
|
||||
use_sglang: bool = field(
|
||||
default=False,
|
||||
metadata={
|
||||
"help": "Whether to use SGLang for generating completions. If set to `True`, the trainer will use SGLang for "
|
||||
"generation instead of the default model.generate(). Requires `sglang` to be installed."
|
||||
},
|
||||
)
|
||||
sglang_mode: str = field(
|
||||
default="server",
|
||||
metadata={
|
||||
"help": "Mode to use for SGLang integration when `use_sglang` is set to `True`. Must be one of `server` or "
|
||||
"`'colocate'`. `'server'`: The trainer will send generation requests to a separate SGLang server. Make sure "
|
||||
"a TRL SGLang server is running (start with `trl sglang-serve`). `'colocate'`: SGLang will run in the same "
|
||||
"process and share the training GPUs. This avoids the need for a separate server but may cause resource "
|
||||
"contention with training."
|
||||
},
|
||||
)
|
||||
sglang_server_base_url: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "Base URL for the SGLang server (e.g., 'http://localhost:8001'). If provided, `sglang_server_host` "
|
||||
"and `sglang_server_port` are ignored."
|
||||
},
|
||||
)
|
||||
sglang_server_host: str = field(
|
||||
default="0.0.0.0",
|
||||
metadata={"help": "Host of the SGLang server to connect to. Ignored if sglang_server_base_url is provided."},
|
||||
)
|
||||
sglang_server_port: int = field(
|
||||
default=8001,
|
||||
metadata={"help": "Port of the SGLang server to connect to. Ignored if sglang_server_base_url is provided."},
|
||||
)
|
||||
sglang_server_timeout: float = field(
|
||||
default=240.0,
|
||||
metadata={
|
||||
"help": "Total timeout duration in seconds to wait for the SGLang server to be up. If the server is not up "
|
||||
"after the timeout, a `ConnectionError` is raised."
|
||||
},
|
||||
)
|
||||
|
||||
# Parameters that control colocated SGLang execution (only used when `sglang_mode` is `"colocate"`)
|
||||
sglang_gpu_memory_utilization: float = field(
|
||||
default=0.3,
|
||||
metadata={
|
||||
"help": "Control the GPU memory utilization for SGLang. This setting only applies when `sglang_mode` is set "
|
||||
"to `'colocate'`. If you are using `sglang_mode='server'`, this parameter must be passed separately when "
|
||||
"launching the SGLang server via the `--gpu-memory-utilization` flag."
|
||||
},
|
||||
)
|
||||
sglang_tensor_parallel_size: int = field(
|
||||
default=1,
|
||||
metadata={
|
||||
"help": "Control the tensor parallel size for SGLang. This setting only applies when `sglang_mode` is set "
|
||||
"to `'colocate'`. If you are using `sglang_mode='server'`, this parameter must be passed separately when "
|
||||
"launching the SGLang server via the `--tensor-parallel-size` flag."
|
||||
},
|
||||
)
|
||||
sglang_pipeline_parallel_size: int = field(
|
||||
default=1,
|
||||
metadata={
|
||||
"help": "Control the pipeline parallel size for SGLang. This setting only applies when `sglang_mode` is set "
|
||||
"to `'colocate'`. If you are using `sglang_mode='server'`, this parameter must be passed separately when "
|
||||
"launching the SGLang server via the `--pipeline-parallel-size` flag."
|
||||
},
|
||||
)
|
||||
|
||||
# Parameters for slime-style bucketed weight updates
|
||||
use_sglang_bucketed_updates: bool = field(
|
||||
default=True,
|
||||
metadata={
|
||||
"help": "Whether to use slime-style bucketed weight updates for SGLang. This provides better memory management "
|
||||
"and performance for large models by grouping parameters into memory-aware buckets."
|
||||
},
|
||||
)
|
||||
sglang_update_weight_buffer_size: int = field(
|
||||
default=512 * 1024**2, # 512MB
|
||||
metadata={
|
||||
"help": "Buffer size for SGLang weight updates in bytes. Parameters are grouped into buckets that don't exceed "
|
||||
"this size to prevent memory issues. Default is 512MB. Increase for better batching, decrease if encountering OOM."
|
||||
},
|
||||
)
|
||||
sglang_pause_generation_during_update: bool = field(
|
||||
default=True,
|
||||
metadata={
|
||||
"help": "Whether to pause SGLang generation during weight updates to prevent race conditions and ensure "
|
||||
"consistent inference results. Recommended for production use."
|
||||
},
|
||||
)
|
||||
sglang_data_parallel_size: int = field(
|
||||
default=1,
|
||||
metadata={
|
||||
"help": "Control the data parallel size for SGLang. This setting only applies when `sglang_mode` is set "
|
||||
"to `'colocate'`. Defaults to 1."
|
||||
},
|
||||
)
|
||||
sglang_enable_dp_attention: bool = field(
|
||||
default=False,
|
||||
metadata={
|
||||
"help": "Enable distributed attention for SGLang when using data parallelism. Required when "
|
||||
"`sglang_data_parallel_size` > 1."
|
||||
},
|
||||
)
|
||||
|
||||
# Parameters that control the training
|
||||
beta: float = field(
|
||||
default=0.0,
|
||||
@ -626,3 +730,15 @@ class GRPOConfig(TrainingArguments):
|
||||
|
||||
if self.delta is not None and self.use_liger_loss:
|
||||
raise ValueError("Liger loss does not support two-sided GRPO loss yet.")
|
||||
|
||||
# SGLang validation
|
||||
if self.use_sglang and self.use_vllm:
|
||||
raise ValueError("Cannot use both SGLang and vLLM at the same time.")
|
||||
|
||||
if self.use_sglang:
|
||||
if self.sglang_mode not in ["server", "colocate"]:
|
||||
raise ValueError(f"sglang_mode must be either 'server' or 'colocate', got '{self.sglang_mode}'")
|
||||
|
||||
if self.sglang_mode == "colocate":
|
||||
if self.sglang_data_parallel_size > 1 and not self.sglang_enable_dp_attention:
|
||||
raise ValueError("sglang_enable_dp_attention must be True when sglang_data_parallel_size > 1")
|
||||
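Putting the new fields together, a colocate-mode configuration that exercises the bucketed updates might look like this (values are illustrative, not recommendations):

```python
from trl import GRPOConfig

training_args = GRPOConfig(
    output_dir="out",
    use_sglang=True,
    sglang_mode="colocate",
    use_sglang_bucketed_updates=True,
    sglang_update_weight_buffer_size=256 * 1024**2,  # smaller buckets if memory is tight
    sglang_pause_generation_during_update=True,
    sglang_data_parallel_size=2,
    sglang_enable_dp_attention=True,  # required by the validation above when DP > 1
)
```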
|
@ -14,6 +14,7 @@
|
||||
|
||||
import copy
|
||||
import inspect
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import textwrap
|
||||
@ -53,8 +54,10 @@ from transformers.utils import is_datasets_available, is_flash_attn_2_available,
|
||||
|
||||
from ..data_utils import apply_chat_template, is_conversational, maybe_apply_chat_template
|
||||
from ..extras.profiling import profiling_context, profiling_decorator
|
||||
from ..extras.sglang_client import SGLangClient
|
||||
from ..extras.sglang_weight_utils import SGLangWeightUpdater
|
||||
from ..extras.vllm_client import VLLMClient
|
||||
from ..import_utils import is_liger_kernel_available, is_vllm_available
|
||||
from ..import_utils import is_liger_kernel_available, is_sglang_available, is_vllm_available
|
||||
from ..models import prepare_deepspeed, prepare_fsdp, unwrap_model_for_generation
|
||||
from ..models.utils import _ForwardRedirection
|
||||
from .callbacks import SyncRefModelCallback
|
||||
@ -80,6 +83,9 @@ if is_vllm_available():
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.sampling_params import GuidedDecodingParams
|
||||
|
||||
if is_sglang_available():
|
||||
from ..extras.sglang_engine_adapter import SGLangEngine
|
||||
|
||||
if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
@ -87,6 +93,8 @@ if is_wandb_available():
|
||||
# rewards. When it's a string, it's a model ID, so it's loaded as a pretrained model.
|
||||
RewardFunc = Union[str, PreTrainedModel, Callable[[list, list], list[float]]]
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RepeatSampler(Sampler):
|
||||
"""
|
||||
@ -667,6 +675,10 @@ class GRPOTrainer(Trainer):
|
||||
self.vllm_mode = args.vllm_mode
|
||||
self.vllm_gpu_memory_utilization = args.vllm_gpu_memory_utilization # only applies to colocation mode
|
||||
self.vllm_tensor_parallel_size = args.vllm_tensor_parallel_size # only applies to colocation mode
|
||||
self.use_sglang = args.use_sglang
|
||||
self.sglang_mode = args.sglang_mode
|
||||
self.sglang_gpu_memory_utilization = args.sglang_gpu_memory_utilization # only applies to colocation mode
|
||||
self.sglang_tensor_parallel_size = args.sglang_tensor_parallel_size # only applies to colocation mode
|
||||
self.use_liger_loss = args.use_liger_loss
|
||||
self.loss_type = args.loss_type
|
||||
self.scale_rewards = args.scale_rewards
|
||||
@ -857,7 +869,108 @@ class GRPOTrainer(Trainer):
|
||||
# desynchronization and seems to lead to DeepSpeed hanging during initialization. To prevent this, we
|
||||
# synchronize all processes after vLLM has been fully initialized.
|
||||
self.accelerator.wait_for_everyone()
|
||||
else:
|
||||
|
||||
if self.use_sglang:
|
||||
if not is_sglang_available():
|
||||
raise ImportError(
|
||||
"SGLang is not available and `use_sglang` is set to True. Please install SGLang with "
|
||||
"`pip install sglang` to use it."
|
||||
)
|
||||
|
||||
if self.sglang_mode == "server" and self.accelerator.is_main_process:
|
||||
if args.sglang_server_base_url is not None:
|
||||
base_url = args.sglang_server_base_url
|
||||
else:
|
||||
base_url = f"http://{args.sglang_server_host}:{args.sglang_server_port}"
|
||||
self.sglang_client = SGLangClient(base_url=base_url, connection_timeout=args.sglang_server_timeout)
|
||||
self.sglang_client.init_communicator(device=torch.cuda.current_device())
|
||||
|
||||
elif self.sglang_mode == "colocate":
|
||||
# Make sure sglang_tensor_parallel_size group size evenly divides the world size
|
||||
if self.accelerator.num_processes % self.sglang_tensor_parallel_size != 0:
|
||||
raise ValueError(
|
||||
f"sglang_tensor_parallel_size ({self.sglang_tensor_parallel_size}) must divide world size "
|
||||
f"({self.accelerator.num_processes}) evenly."
|
||||
)
|
||||
|
||||
if self.sglang_tensor_parallel_size > 1:
|
||||
# Create subgroups of ranks for TP
|
||||
self.sglang_tp_group, _ = torch.distributed.new_subgroups_by_enumeration(
|
||||
[
|
||||
list(
|
||||
range(i * self.sglang_tensor_parallel_size, (i + 1) * self.sglang_tensor_parallel_size)
|
||||
)
|
||||
for i in range(self.accelerator.num_processes // self.sglang_tensor_parallel_size)
|
||||
]
|
||||
)
|
||||
|
||||
# Set environment variables for SGLang distributed training
|
||||
os.environ["RANK"] = str(self.accelerator.process_index)
|
||||
os.environ["LOCAL_RANK"] = str(self.accelerator.local_process_index)
|
||||
os.environ["WORLD_SIZE"] = str(self.accelerator.num_processes)
|
||||
os.environ["MASTER_ADDR"] = getattr(self.accelerator.state, "main_process_ip", "localhost")
|
||||
os.environ["MASTER_PORT"] = str(getattr(self.accelerator.state, "main_process_port", 29500))
|
||||
|
||||
# Initialize SGLang engine (colocate mode) using improved adapter
|
||||
# Add required attributes for SGLangEngine
|
||||
args.sglang_model_path = model.name_or_path
|
||||
args.sglang_host = "127.0.0.1"
|
||||
args.sglang_num_gpus_per_node = torch.cuda.device_count()
|
||||
args.colocate = True
|
||||
args.offload = False
|
||||
|
||||
# Create SGLang engine
|
||||
port = 8001 + self.accelerator.process_index # Different port per process
|
||||
nccl_port = 29500 + self.accelerator.process_index
|
||||
dist_init_addr = f"{getattr(self.accelerator.state, 'main_process_ip', 'localhost')}:{getattr(self.accelerator.state, 'main_process_port', 29500)}"
|
||||
|
||||
self.sglang_engine = SGLangEngine(
|
||||
args=args,
|
||||
rank=self.accelerator.process_index,
|
||||
dist_init_addr=dist_init_addr,
|
||||
port=port,
|
||||
nccl_port=nccl_port,
|
||||
)
|
||||
|
||||
# Initialize weight update group (required before weight updates)
|
||||
try:
|
||||
self.sglang_engine.init_process_group(
|
||||
master_address=getattr(self.accelerator.state, "main_process_ip", "localhost"),
|
||||
master_port=getattr(self.accelerator.state, "main_process_port", 29500),
|
||||
rank_offset=self.accelerator.process_index,
|
||||
world_size=self.accelerator.num_processes,
|
||||
group_name="sglang_weight_sync",
|
||||
backend="nccl",
|
||||
)
|
||||
logger.info("Initialized SGLang weight update group")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to initialize SGLang weight update group: {e}")
|
||||
logger.warning(
|
||||
"This may cause issues with weight updates. Consider using a patched SGLang version or disabling bucketed updates."
|
||||
)
|
||||
|
||||
self._last_loaded_step = -1 # tag to avoid useless loading during grad accumulation
|
||||
|
||||
# Initialize slime-style weight updater if enabled
|
||||
if args.use_sglang_bucketed_updates:
|
||||
self.sglang_weight_updater = SGLangWeightUpdater(
|
||||
model=self.model,
|
||||
sglang_mode=self.sglang_mode,
|
||||
sglang_client=getattr(self, "sglang_client", None),
|
||||
sglang_engine=getattr(self, "sglang_engine", None),
|
||||
accelerator=self.accelerator,
|
||||
update_weight_buffer_size=args.sglang_update_weight_buffer_size,
|
||||
)
|
||||
logger.info(
|
||||
f"Initialized slime-style SGLang weight updater with buffer size: {args.sglang_update_weight_buffer_size / (1024**2):.2f}MB"
|
||||
)
|
||||
else:
|
||||
self.sglang_weight_updater = None
|
||||
|
||||
# Synchronize all processes after SGLang initialization
|
||||
self.accelerator.wait_for_everyone()
|
||||
|
||||
if not self.use_vllm and not self.use_sglang:
|
||||
generation_kwargs = {
|
||||
"max_new_tokens": self.max_completion_length,
|
||||
"do_sample": True,
|
||||
@ -1255,6 +1368,254 @@ class GRPOTrainer(Trainer):
|
||||
elif self.vllm_mode == "colocate":
|
||||
self.llm.reset_prefix_cache()
|
||||
|
||||
def _fix_param_name_to_sglang(self, name, extra_prefixes: Optional[list[str]] = None):
|
||||
"""Fix parameter names for SGLang compatibility."""
|
||||
extra_prefixes = extra_prefixes or []
|
||||
prefixes = ["_checkpoint_wrapped_module."] + extra_prefixes
|
||||
for prefix in prefixes:
|
||||
name = name.replace(prefix, "")
|
||||
return name
|
||||
|
||||
def _sync_fsdp1_params_to_sglang(self, module: nn.Module, prefix: str = "", visited=None, batch_params=None):
|
||||
"""Memory-efficient post-order traversal of FSDP modules to extract full parameters and sync with SGLang."""
|
||||
if visited is None:
|
||||
visited = set()
|
||||
if batch_params is None:
|
||||
batch_params = {"names": [], "dtypes": [], "shapes": [], "params": []}
|
||||
is_root_call = True
|
||||
else:
|
||||
is_root_call = False
|
||||
|
||||
for child_name, child_module in module.named_children():
|
||||
child_prefix = f"{prefix}.{child_name}" if prefix else child_name
|
||||
self._sync_fsdp1_params_to_sglang(
|
||||
child_module, prefix=child_prefix, visited=visited, batch_params=batch_params
|
||||
)
|
||||
|
||||
if isinstance(module, FSDP):
|
||||
with FSDP.summon_full_params(module, recurse=False, writeback=False):
|
||||
for param_name, param in module.named_parameters():
|
||||
full_name = f"{prefix}.{param_name}" if prefix else param_name
|
||||
full_name = self._fix_param_name_to_sglang(full_name, extra_prefixes=["_fsdp_wrapped_module."])
|
||||
|
||||
if full_name in visited:
|
||||
continue
|
||||
visited.add(full_name)
|
||||
|
||||
# Collect parameters for batch update
|
||||
batch_params["names"].append(full_name)
|
||||
batch_params["dtypes"].append(str(param.data.dtype))
|
||||
batch_params["shapes"].append(list(param.data.shape))
|
||||
batch_params["params"].append(param)
|
||||
|
||||
# If this is the root call, perform the batched update
|
||||
if is_root_call and batch_params["names"]:
|
||||
if self.sglang_mode == "server" and self.accelerator.is_main_process:
|
||||
# Use SGLang client's batch update API
|
||||
url = f"{self.sglang_client.base_url}/update_weights/"
|
||||
response = self.sglang_client.session.post(
|
||||
url,
|
||||
json={
|
||||
"names": batch_params["names"],
|
||||
"dtypes": batch_params["dtypes"],
|
||||
"shapes": batch_params["shapes"],
|
||||
"group_name": "weight_sync",
|
||||
"flush_cache": False,
|
||||
},
|
||||
)
|
||||
if response.status_code != 200:
|
||||
raise Exception(f"SGLang FSDP1 weight update failed: {response.status_code}, {response.text}")
|
||||
elif self.sglang_mode == "colocate":
|
||||
# Single NCCL operation for all FSDP parameters
|
||||
self.sglang_engine.update_weights_from_distributed(
|
||||
batch_params["names"], batch_params["dtypes"], batch_params["shapes"], "weight_sync"
|
||||
)
|
||||
|
||||
def _sync_fsdp2_params_to_sglang(self, module: nn.Module):
|
||||
"""Sync FSDP2 parameters to SGLang using batched updates."""
|
||||
names, dtypes, shapes, params = [], [], [], []
|
||||
|
||||
for name, param in module.state_dict().items():
|
||||
if param.is_cpu:
|
||||
param = param.to(torch.device("cuda"))
|
||||
param = param.full_tensor()
|
||||
|
||||
names.append(name)
|
||||
dtypes.append(str(param.dtype))
|
||||
shapes.append(list(param.shape))
|
||||
params.append(param)
|
||||
|
||||
# Batched update for all FSDP2 parameters
|
||||
if names:
|
||||
if self.sglang_mode == "server" and self.accelerator.is_main_process:
|
||||
# Use SGLang client's batch update API
|
||||
url = f"{self.sglang_client.base_url}/update_weights/"
|
||||
response = self.sglang_client.session.post(
|
||||
url,
|
||||
json={
|
||||
"names": names,
|
||||
"dtypes": dtypes,
|
||||
"shapes": shapes,
|
||||
"group_name": "weight_sync",
|
||||
"flush_cache": False,
|
||||
},
|
||||
)
|
||||
if response.status_code != 200:
|
||||
raise Exception(f"SGLang FSDP2 weight update failed: {response.status_code}, {response.text}")
|
||||
elif self.sglang_mode == "colocate":
|
||||
# Single NCCL operation for all FSDP2 parameters
|
||||
self.sglang_engine.update_weights_from_distributed(names, dtypes, shapes, "weight_sync")
|
||||
|
||||
    @profiling_decorator
    def _move_model_to_sglang(self):
        """Move model weights to SGLang for inference."""
        # Use slime-style bucketed updates if available
        if self.sglang_weight_updater is not None:
            # Pause generation during weight updates if configured
            if (
                hasattr(self.args, "sglang_pause_generation_during_update")
                and self.args.sglang_pause_generation_during_update
            ):
                if self.sglang_mode == "server" and self.accelerator.is_main_process:
                    self.sglang_client.pause_generation()
                elif self.sglang_mode == "colocate":
                    self.sglang_engine.pause_generation()

            # Use the advanced weight updater
            deepspeed_plugin = self.accelerator.state.deepspeed_plugin
            self.sglang_weight_updater.update_model_weights(deepspeed_plugin=deepspeed_plugin)

            # Resume generation if paused
            if (
                hasattr(self.args, "sglang_pause_generation_during_update")
                and self.args.sglang_pause_generation_during_update
            ):
                if self.sglang_mode == "server" and self.accelerator.is_main_process:
                    self.sglang_client.continue_generation()
                elif self.sglang_mode == "colocate":
                    self.sglang_engine.continue_generation()

            return  # Exit early - the advanced updater handles everything

        # Fall back to the original implementation if bucketed updates are disabled
        logger.warning(
            "Using legacy weight update method. Consider enabling 'use_sglang_bucketed_updates' for better performance."
        )

        # For DeepSpeed ZeRO-3 and FSDP, we need to gather all parameters before operations
        deepspeed_plugin = self.accelerator.state.deepspeed_plugin
        zero_stage_3 = deepspeed_plugin is not None and deepspeed_plugin.zero_stage == 3
        if zero_stage_3:
            import deepspeed

            gather_if_zero3 = deepspeed.zero.GatheredParameters
        else:
            gather_if_zero3 = nullcontext

        if is_peft_model(self.model):
            # With PEFT and FSDP/DeepSpeed ZeRO Stage 3, we must gather the full model at once before merging
            with gather_if_zero3(list(self.model.parameters())):
                self.model.merge_adapter()

                # Update SGLang weights while parameters are gathered
                if self.is_fsdp_enabled:  # note: if using FSDP, gather_if_zero3 is nullcontext
                    # For PEFT with FSDP we need to use the memory-efficient post-order traversal
                    fsdp_plugin = getattr(self.accelerator.state, "fsdp_plugin", None)
                    fsdp_version = getattr(fsdp_plugin, "fsdp_version", 1) if fsdp_plugin else 1
                    if fsdp_version == 1:
                        self._sync_fsdp1_params_to_sglang(self.model)
                    elif fsdp_version == 2:
                        self._sync_fsdp2_params_to_sglang(self.model)
                else:
                    # DeepSpeed ZeRO-3 with PEFT - batch all parameter updates
                    names, dtypes, shapes = [], [], []
                    for name, param in self.model.named_parameters():
                        # When using PEFT, we need to recover the original parameter name and discard some parameters
                        name = name.removeprefix("base_model.model.").replace(".base_layer", "")
                        if self.model.prefix in name:
                            continue
                        # For modules_to_save, remove the prefix and discard the original module
                        if "original_module" in name:
                            continue
                        name = self._fix_param_name_to_sglang(name, extra_prefixes=["modules_to_save.default."])
                        names.append(name)
                        dtypes.append(str(param.data.dtype))
                        shapes.append(list(param.data.shape))

                    # Single batched update for all parameters
                    if names:  # Only update if we have parameters
                        if self.sglang_mode == "server" and self.accelerator.is_main_process:
                            # Use SGLang client's batch update API
                            url = f"{self.sglang_client.base_url}/update_weights/"
                            response = self.sglang_client.session.post(
                                url,
                                json={
                                    "names": names,
                                    "dtypes": dtypes,
                                    "shapes": shapes,
                                    "group_name": "weight_sync",
                                    "flush_cache": False,  # Don't flush here, do it at the end
                                },
                            )
                            if response.status_code != 200:
                                raise Exception(
                                    f"SGLang weight update failed: {response.status_code}, {response.text}"
                                )
                        elif self.sglang_mode == "colocate":
                            # Single NCCL operation for all parameters
                            self.sglang_engine.update_weights_from_distributed(names, dtypes, shapes, "weight_sync")
                # Unmerge adapters while parameters are still gathered
                self.model.unmerge_adapter()
                # Parameters will automatically be repartitioned when exiting the context
        else:
            # Regular model parameters - no PEFT
            if self.is_fsdp_enabled:
                fsdp_plugin = getattr(self.accelerator.state, "fsdp_plugin", None)
                fsdp_version = getattr(fsdp_plugin, "fsdp_version", 1) if fsdp_plugin else 1
                if fsdp_version == 1:
                    self._sync_fsdp1_params_to_sglang(self.model)
                elif fsdp_version == 2:
                    self._sync_fsdp2_params_to_sglang(self.model)
            else:
                # Regular model parameters - batch all parameter updates
                names, dtypes, shapes = [], [], []
                params_to_gather = []

                for name, param in self.model.named_parameters():
                    name = self._fix_param_name_to_sglang(name)
                    names.append(name)
                    dtypes.append(str(param.dtype))
                    shapes.append(list(param.shape))
                    params_to_gather.append(param)

                # Gather all parameters at once for DeepSpeed ZeRO-3
                with gather_if_zero3(params_to_gather):
                    # Single batched update for all parameters
                    if self.sglang_mode == "server" and self.accelerator.is_main_process:
                        # Use SGLang client's batch update API
                        url = f"{self.sglang_client.base_url}/update_weights/"
                        response = self.sglang_client.session.post(
                            url,
                            json={
                                "names": names,
                                "dtypes": dtypes,
                                "shapes": shapes,
                                "group_name": "weight_sync",
                                "flush_cache": False,  # Don't flush here, do it at the end
                            },
                        )
                        if response.status_code != 200:
                            raise Exception(f"SGLang weight update failed: {response.status_code}, {response.text}")
                    elif self.sglang_mode == "colocate":
                        # Single NCCL operation for all parameters
                        self.sglang_engine.update_weights_from_distributed(names, dtypes, shapes, "weight_sync")

        # Reset cache on SGLang
        if self.sglang_mode == "server" and self.accelerator.is_main_process:
            self.sglang_client.flush_cache()
        elif self.sglang_mode == "colocate":
            self.sglang_engine.reset_prefix_cache()

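    # The cache handling at the end of _move_model_to_sglang matters: once the weights change
    # in place, any prefix/KV entries cached by the engine were computed with the old weights,
    # so server mode flushes the server-side cache over HTTP (flush_cache) and colocate mode
    # resets the in-process engine's prefix cache (reset_prefix_cache) before the next
    # generation step.
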
    @profiling_decorator
    def _prepare_inputs(
        self, generation_batch: dict[str, Union[torch.Tensor, Any]]

@@ -1524,6 +1885,95 @@ class GRPOTrainer(Trainer):
                        tp_slice = slice(local_rank_in_group * orig_size, (local_rank_in_group + 1) * orig_size)
                        completion_ids = completion_ids[tp_slice]

        elif self.use_sglang:
            # First, update the SGLang weights if needed
            if self.state.global_step != self._last_loaded_step:
                self._move_model_to_sglang()
                self._last_loaded_step = self.state.global_step

            # Generate completions using SGLang
            if self.sglang_mode == "server":
                all_prompts_text = gather_object(prompts_text)
                if has_images:
                    all_images = gather_object(images)

                if self.accelerator.is_main_process:
                    # Prepare sampling parameters for SGLang
                    sampling_params = {
                        "temperature": self.temperature,
                        "top_p": self.top_p,
                        "top_k": -1 if self.top_k is None else self.top_k,
                        "max_new_tokens": self.max_completion_length,
                        "min_p": 0.0 if self.min_p is None else self.min_p,
                        "repetition_penalty": self.repetition_penalty,
                    }
                    # Add any additional generation kwargs
                    if self.args.generation_kwargs is not None:
                        sampling_params.update(self.args.generation_kwargs)

                    with profiling_context(self, "SGLang.generate"):
                        completion_ids = self.sglang_client.generate(
                            prompts=all_prompts_text,
                            images=all_images if has_images else None,
                            sampling_params=sampling_params,
                        )

                    # Broadcast completion_ids to all processes
                    completion_ids = broadcast_object_list([completion_ids])[0]
                else:
                    completion_ids = broadcast_object_list([None])[0]

                # Distribute completions back to processes
                # gather_object([prompts_text]) returns one prompt list per process, so len() gives per-process counts
                num_prompts_per_process = [len(p) for p in gather_object([prompts_text])]
                start_idx = sum(num_prompts_per_process[: self.accelerator.process_index])
                end_idx = start_idx + num_prompts_per_process[self.accelerator.process_index]
                completion_ids = completion_ids[start_idx:end_idx]

            elif self.sglang_mode == "colocate":
                # For colocate mode, each process generates its own completions
                sampling_params = {
                    "temperature": self.temperature,
                    "top_p": self.top_p,
                    "top_k": -1 if self.top_k is None else self.top_k,
                    "max_new_tokens": self.max_completion_length,
                    "min_p": 0.0 if self.min_p is None else self.min_p,
                    "repetition_penalty": self.repetition_penalty,
                }
                # Add any additional generation kwargs
                if self.args.generation_kwargs is not None:
                    sampling_params.update(self.args.generation_kwargs)

                if self.sglang_tensor_parallel_size > 1:
                    # Gather prompts from all ranks in the TP group
                    orig_size = len(prompts_text)
                    gathered_prompts = [None for _ in range(self.sglang_tensor_parallel_size)]
                    torch.distributed.all_gather_object(gathered_prompts, prompts_text, group=self.sglang_tp_group)
                    all_prompts_text = [p for sublist in gathered_prompts for p in sublist]

                    if has_images:
                        gathered_images = [None for _ in range(self.sglang_tensor_parallel_size)]
                        torch.distributed.all_gather_object(gathered_images, images, group=self.sglang_tp_group)
                        all_images = [img for sublist in gathered_images for img in sublist]
                    else:
                        all_images = None
                else:
                    all_prompts_text = prompts_text
                    all_images = images if has_images else None

                with profiling_context(self, "SGLang.generate"):
                    # Use SGLang engine to generate completions
                    completion_ids = self.sglang_engine.generate(
                        prompts=all_prompts_text,
                        sampling_params=sampling_params,
                        images=all_images,
                    )

                if self.sglang_tensor_parallel_size > 1:
                    # Slice completions for this rank within its TP group
                    local_rank_in_group = torch.distributed.get_rank(group=self.sglang_tp_group)
                    tp_slice = slice(local_rank_in_group * orig_size, (local_rank_in_group + 1) * orig_size)
                    completion_ids = completion_ids[tp_slice]

            # Pad the completions, and concatenate them with the prompts
            completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids]
            completion_ids = pad(completion_ids, padding_value=self.pad_token_id)
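            # pad() here is assumed to be TRL's padding utility: it right-pads the variable-length
            # completion tensors into a single (num_completions, max_len) tensor filled with
            # pad_token_id. A minimal sketch of the expected behaviour (values illustrative):
            #
            #   pad([torch.tensor([1, 2, 3]), torch.tensor([4])], padding_value=0)
            #   # -> tensor([[1, 2, 3],
            #   #            [4, 0, 0]])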