mirror of https://github.com/volcengine/verl.git
[sglang] fix: add memory saver support to sglang rollout to avoid OOMs (#756)

As title: enable torch-memory-saver support in the SGLang rollout so the inference engine's GPU memory can be released while the trainer owns the GPU and resumed before the next generation phase, avoiding OOMs.

Co-authored-by: ocss884 <ocss.lin@gmail.com>

.github/workflows/e2e_sglang_gsm8k.yml (vendored, 2 changed lines)

@@ -39,7 +39,7 @@ jobs:
       NO_PROXY: "localhost,127.0.0.1"
       HF_HUB_ENABLE_HF_TRANSFER: 1
     container:
-      image: ocss884/verl-sglang:ngc-th2.5.1-cu126-sglang0.4.3.post3
+      image: ocss884/verl-sglang:ngc-th2.5.1-cu126-sglang0.4.4.post3
       options: --gpus all --shm-size=10g
     steps:
       - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2

requirements_sglang.txt (new file, 21 lines)

@@ -0,0 +1,21 @@
+# requirements.txt records the full set of dependencies for development
+accelerate
+codetiming
+datasets
+dill
+flash-attn
+hydra-core
+numpy
+pandas
+peft
+pyarrow>=15.0.0
+pybind11
+pylatexenc
+ray[default]>=2.10
+tensordict<=0.6.2
+torchdata
+torchvision
+transformers
+wandb
+sglang[all]==0.4.4.post3
+torch-memory-saver>=0.0.5

setup.py (6 changed lines)

@@ -46,7 +46,11 @@ GEO_REQUIRES = ['mathruler']
 GPU_REQUIRES = ['liger-kernel', 'flash-attn']
 MATH_REQUIRES = ['math-verify']  # Add math-verify as an optional dependency
 VLLM_REQUIRES = ['tensordict<=0.6.2', 'vllm<=0.8.2']
-SGLANG_REQUIRES = ['tensordict<=0.6.2', 'sglang[all]==0.4.4']
+SGLANG_REQUIRES = [
+    'tensordict<=0.6.2',
+    'sglang[all]==0.4.4.post3',
+    'torch-memory-saver>=0.0.5'
+]
 
 extras_require = {
     'test': TEST_REQUIRES,
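
The new pins are consumed through a setuptools extra. A minimal sketch of the wiring, assuming the extra is published under the name 'sglang' (only SGLANG_REQUIRES and the opening of extras_require appear in the diff; the rest is illustrative):

# sketch_setup.py -- hypothetical, condensed from setup.py
from setuptools import setup

SGLANG_REQUIRES = [
    'tensordict<=0.6.2',
    'sglang[all]==0.4.4.post3',
    'torch-memory-saver>=0.0.5',  # new runtime dependency for the memory saver
]

setup(
    name='verl',
    # `pip install -e .[sglang]` (extra name assumed) pulls in the pins above.
    extras_require={'sglang': SGLANG_REQUIRES},
)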

verl/third_party/sglang/parallel_state.py (vendored, 7 changed lines)

@@ -9,16 +9,13 @@ from typing import Optional
 
 import torch
 import torch.distributed
-import vllm.distributed.parallel_state as ps
-from vllm.distributed.parallel_state import (
+import sglang.srt.distributed.parallel_state as ps
+from sglang.srt.distributed.parallel_state import (
     get_pp_group,
     get_world_group,
     init_distributed_environment,
     init_model_parallel_group,
 )
-from vllm.logger import init_logger
-
-logger = init_logger(__name__)
 """
 This version is strongly tied with Megatron to implement HybridEngine and weight sharing between vllm and Megatron.
 - We assume the Megatron tp+dp+pp world is already established before calling this function.
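
The import swap assumes sglang.srt.distributed vendors the same vLLM-derived parallel_state API that verl previously imported from vllm.distributed. A minimal sketch of driving those entry points; the keyword names are assumptions based on that vLLM lineage, and the env:// rendezvous needs MASTER_ADDR/MASTER_PORT set:

# sketch_parallel_state.py -- signatures assumed to mirror vLLM's API
from sglang.srt.distributed.parallel_state import (
    get_world_group,
    init_distributed_environment,
)

# Single-process bring-up for illustration only.
init_distributed_environment(
    world_size=1,
    rank=0,
    distributed_init_method="env://",
    local_rank=0,
    backend="nccl",
)
print(get_world_group().world_size)  # -> 1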

@@ -146,6 +146,7 @@ class SGLangRollout(BaseRollout):
             dtype=config.dtype,
             mem_fraction_static=config.gpu_memory_utilization,
             device_mesh_cpu=device_mesh_cpu["tp"],
+            enable_memory_saver=True,
             base_gpu_id=0,
             gpu_id_step=1,
             # NOTE(Chenyang): if you want to debug the sglang engine
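
The added enable_memory_saver=True is what allows the engine's weights and KV cache to be paused while the trainer owns the GPU. A condensed sketch of the construction site; only the keyword arguments shown in the hunk come from the source, while the VerlEngine import path and the helper's other names are assumptions:

# sketch_rollout_engine.py -- condensed; import path and helper are assumed
from sglang.srt.entrypoints.verl_engine import VerlEngine

def build_rollout_engine(config, device_mesh_cpu, model_path):
    """Build the SGLang rollout engine with memory-saver support enabled."""
    return VerlEngine(
        model_path=model_path,                          # illustrative placeholder
        dtype=config.dtype,
        mem_fraction_static=config.gpu_memory_utilization,
        device_mesh_cpu=device_mesh_cpu["tp"],
        enable_memory_saver=True,                       # new flag from this commit
        base_gpu_id=0,
        gpu_id_step=1,
    )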

@@ -80,11 +80,13 @@ class FSDPSGLangShardingManager(BaseShardingManager):
         self.gen_random_states = None
 
     def __enter__(self):
         torch.cuda.empty_cache()
         log_gpu_memory_usage('Before state_dict() in sharding manager memory', logger=logger)
         params = self.module.state_dict()
         log_gpu_memory_usage('After state_dict() in sharding manager memory', logger=logger)
         # Copy, not share memory
         load_format = None if self.full_params else 'dtensor'
+        self.inference_engine.resume_memory_occupation()
+
         self.inference_engine.update_weights_from_tensor([(k, v) for k, v in params.items()], load_format=None)
         log_gpu_memory_usage('After sync model weights in sharding manager', logger=logger)
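
resume_memory_occupation() re-materializes the engine memory released at the end of the previous rollout, and must run before weights are pushed with update_weights_from_tensor(). The primitive underneath is torch-memory-saver's pause/resume of CUDA allocations; a standalone sketch, assuming the region/pause/resume API described in that project's README:

# sketch_memory_saver.py -- API names assumed from the torch-memory-saver README
import torch
from torch_memory_saver import torch_memory_saver

# Tensors allocated inside the region use a pausable CUDA allocator.
with torch_memory_saver.region():
    kv_cache = torch.zeros(1_000_000_000, dtype=torch.uint8, device='cuda')

torch_memory_saver.pause()   # physical GPU memory freed; virtual addresses kept
# ... training can temporarily use the freed memory here ...
torch_memory_saver.resume()  # memory re-occupied at the same addresses
kv_cache.fill_(1)            # tensor usable again, no re-allocation needed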

@@ -106,7 +108,7 @@ class FSDPSGLangShardingManager(BaseShardingManager):
 
     def __exit__(self, exc_type, exc_value, traceback):
         log_gpu_memory_usage('Before SGLang offload in sharding manager', logger=logger)
-        self.inference_engine.release_memory_occupation
+        self.inference_engine.release_memory_occupation()
         log_gpu_memory_usage('After SGLang offload in sharding manager', logger=logger)
 
         # self.module.to('cuda')
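
Taken together, the sharding manager now brackets every rollout: __enter__ resumes the engine's memory and syncs the latest weights, and __exit__ releases it again. The parentheses fix above matters: the bare attribute access was a no-op, so the engine's memory was never actually freed. A minimal usage sketch, with hypothetical stand-ins for the verl objects:

# sketch_rollout_bracket.py -- names are hypothetical stand-ins
def generate(sharding_manager, rollout, prompts):
    with sharding_manager:
        # __enter__: resume_memory_occupation() + update_weights_from_tensor()
        outputs = rollout.generate_sequences(prompts)
        # __exit__: release_memory_occupation() frees engine memory for training
    return outputs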