[sglang] fix: add memory saver support to sglang rollout to avoid OOMs (#756)

as title

---------

Co-authored-by: ocss884 <ocss.lin@gmail.com>
Author: Xiang Long
Date: 2025-03-30 23:36:16 +08:00
Committed by: haibin.lin
Parent: 52bec83183
Commit: b70981bdb9
6 changed files with 33 additions and 8 deletions
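
In short: the SGLang rollout engine is now started with enable_memory_saver=True (backed by the newly pinned torch-memory-saver package), and the FSDP sharding manager resumes the engine's GPU memory and re-syncs weights when entering rollout, then releases that memory again when returning to training, so the two phases no longer hold GPU memory at the same time. A minimal sketch of that hand-off, based only on the calls visible in the diffs below (the function names and the bare "engine" handle are illustrative, not verl's actual interface):

def enter_rollout(engine, fsdp_module):
    # Re-occupy the GPU memory the engine released while training ran ...
    engine.resume_memory_occupation()
    # ... then push the freshly trained FSDP weights into the inference engine.
    params = fsdp_module.state_dict()
    engine.update_weights_from_tensor(list(params.items()), load_format=None)

def exit_rollout(engine):
    # Hand the GPU back to the trainer by releasing the engine's weights and KV cache.
    engine.release_memory_occupation()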


@@ -39,7 +39,7 @@ jobs:
NO_PROXY: "localhost,127.0.0.1"
HF_HUB_ENABLE_HF_TRANSFER: 1
container:
-image: ocss884/verl-sglang:ngc-th2.5.1-cu126-sglang0.4.3.post3
+image: ocss884/verl-sglang:ngc-th2.5.1-cu126-sglang0.4.4.post3
options: --gpus all --shm-size=10g
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2

requirements_sglang.txt (new file, 21 lines added)

@@ -0,0 +1,21 @@
# requirements.txt records the full set of dependencies for development
accelerate
codetiming
datasets
dill
flash-attn
hydra-core
numpy
pandas
peft
pyarrow>=15.0.0
pybind11
pylatexenc
ray[default]>=2.10
tensordict<=0.6.2
torchdata
torchvision
transformers
wandb
sglang[all]==0.4.4.post3
torch-memory-saver>=0.0.5
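
Note: the file pins sglang and torch-memory-saver together, presumably because the memory-saver hooks require this sglang version. A quick environment check (stdlib only; a sketch, using just the two distribution names pinned above):

from importlib.metadata import version

# Confirm the two pins that matter for memory saving are actually installed.
print("sglang:", version("sglang"))
print("torch-memory-saver:", version("torch-memory-saver"))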


@@ -46,7 +46,11 @@ GEO_REQUIRES = ['mathruler']
GPU_REQUIRES = ['liger-kernel', 'flash-attn']
MATH_REQUIRES = ['math-verify'] # Add math-verify as an optional dependency
VLLM_REQUIRES = ['tensordict<=0.6.2', 'vllm<=0.8.2']
-SGLANG_REQUIRES = ['tensordict<=0.6.2', 'sglang[all]==0.4.4']
+SGLANG_REQUIRES = [
+    'tensordict<=0.6.2',
+    'sglang[all]==0.4.4.post3',
+    'torch-memory-saver>=0.0.5'
+]
extras_require = {
'test': TEST_REQUIRES,
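
The hunk ends before the rest of extras_require; for context, a hedged sketch of how the updated SGLANG_REQUIRES list is typically wired in (the 'sglang' key and the other entries are assumptions, not shown in this diff):

extras_require = {
    'test': TEST_REQUIRES,
    'vllm': VLLM_REQUIRES,
    'sglang': SGLANG_REQUIRES,  # assumed key; would enable: pip install verl[sglang]
}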


@@ -9,16 +9,13 @@ from typing import Optional
import torch
import torch.distributed
-import vllm.distributed.parallel_state as ps
-from vllm.distributed.parallel_state import (
+import sglang.srt.distributed.parallel_state as ps
+from sglang.srt.distributed.parallel_state import (
get_pp_group,
get_world_group,
init_distributed_environment,
init_model_parallel_group,
)
-from vllm.logger import init_logger
-logger = init_logger(__name__)
"""
This version is strongly tied with Megatron to implement HybridEngine and weight sharing between vllm and Megatron.
- We assume the Megatron tp+dp+pp world is already established before calling this function.
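
With the vllm-based init_logger removed, the module presumably still needs some logger; what it uses after this change is not shown in the hunk. A drop-in stdlib sketch:

import logging

# Plain stdlib logger, so the module no longer pulls in any vllm import.
logger = logging.getLogger(__name__)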


@@ -146,6 +146,7 @@ class SGLangRollout(BaseRollout):
dtype=config.dtype,
mem_fraction_static=config.gpu_memory_utilization,
device_mesh_cpu=device_mesh_cpu["tp"],
+enable_memory_saver=True,
base_gpu_id=0,
gpu_id_step=1,
# NOTE(Chenyang): if you want to debug the sglang engine
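
enable_memory_saver=True is what ties the engine to the newly pinned torch-memory-saver package: the engine allocates its weights and KV cache inside a saver region so that memory can later be paused and resumed without tearing the engine down. A rough sketch of the underlying library behaviour (this assumes torch-memory-saver's region/pause/resume interface and is not verl or sglang code):

import torch
from torch_memory_saver import torch_memory_saver

# Tensors allocated inside the region can later have their GPU memory released.
with torch_memory_saver.region():
    kv_cache = torch.empty(1 << 30, dtype=torch.uint8, device="cuda")

torch_memory_saver.pause()   # frees the physical GPU memory behind kv_cache
# ... the trainer can use the freed memory here ...
torch_memory_saver.resume()  # re-materializes it before the next rollout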


@@ -80,11 +80,13 @@ class FSDPSGLangShardingManager(BaseShardingManager):
self.gen_random_states = None
def __enter__(self):
+torch.cuda.empty_cache()
log_gpu_memory_usage('Before state_dict() in sharding manager memory', logger=logger)
params = self.module.state_dict()
log_gpu_memory_usage('After state_dict() in sharding manager memory', logger=logger)
# Copy, not share memory
load_format = None if self.full_params else 'dtensor'
+self.inference_engine.resume_memory_occupation()
self.inference_engine.update_weights_from_tensor([(k, v) for k, v in params.items()], load_format=None)
log_gpu_memory_usage('After sync model weights in sharding manager', logger=logger)
@@ -106,7 +108,7 @@ class FSDPSGLangShardingManager(BaseShardingManager):
def __exit__(self, exc_type, exc_value, traceback):
log_gpu_memory_usage('Before SGLang offload in sharding manager', logger=logger)
-self.inference_engine.release_memory_occupation
+self.inference_engine.release_memory_occupation()
log_gpu_memory_usage('After SGLang offload in sharding manager', logger=logger)
# self.module.to('cuda')
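
The second hunk also fixes a subtle bug: the old line referenced release_memory_occupation without calling it, which is a silent no-op in Python, so the engine never actually gave its memory back. A minimal illustration (generic Python, not verl code):

class Engine:
    def release_memory_occupation(self):
        print("memory released")

engine = Engine()
engine.release_memory_occupation    # no-op: the bound method object is created and discarded
engine.release_memory_occupation()  # the fix: actually invokes the release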