[rollout] feat: use rollout worker in MegatronWorker (#3111)

### What does this PR do?

- Build the rollout inside the Megatron `ActorRolloutRefWorker` through the shared `RolloutWorker` instead of constructing `vLLMRollout` / `SGLangRollout` directly in the worker, and rename the SGLang sharding manager's device-mesh dimension from `tp` to `infer_tp` to match.

### Checklist Before Starting

- [ ] Search for similar PRs. Paste at least one query link here: ...
- [ ] Format the PR title as `[{modules}] {type}: {description}` (This will be checked by the CI)
  - `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data`
  - If this PR involves multiple modules, separate them with `,` like `[megatron, fsdp, doc]`
  - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test`
  - If this PR breaks any API (CLI arguments, config, function signature, etc.), add `[BREAKING]` to the beginning of the title.
    - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching`

### Test

> For changes that can not be tested by CI (e.g., algorithm
implementation, new model support), validate by experiment(s) and show
results like training curve plots, evaluation results, etc.

### API and Usage Example

> Demonstrate how the API changes if any, and provide usage example(s)
if possible.

```python
# Add code snippet or script demonstrating how to use this
```
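
Below is a hedged sketch (not the verbatim worker code) of the new construction path, mirroring `_build_rollout` from the diff in this PR. `worker` stands in for the Megatron `ActorRolloutRefWorker` instance, and the `omega_conf_to_dataclass` / `get_device_name` import paths are assumptions rather than facts taken from this diff.

```python
# Hedged sketch of the new Megatron _build_rollout path; "worker" stands in for
# the ActorRolloutRefWorker instance, and imports marked "assumed" are not
# confirmed by this diff.
import copy

from omegaconf import open_dict
from torch.distributed.device_mesh import init_device_mesh

from verl.utils.config import omega_conf_to_dataclass  # assumed import path
from verl.utils.device import get_device_name          # assumed import path
from verl.workers.config import HFModelConfig, RolloutConfig
from verl.workers.rollout.rollout_worker import RolloutWorker


def build_rollout(worker):
    # Convert the rollout config to its dataclass form.
    rollout_config: RolloutConfig = omega_conf_to_dataclass(worker.config.rollout)

    # Megatron's model override_config differs from FSDP's, so the model config is
    # deep-copied and flattened before converting it to HFModelConfig.
    omega_model_config = copy.deepcopy(worker.config.model)
    with open_dict(omega_model_config):
        omega_model_config.override_config = omega_model_config.override_config.pop("model_config")
    model_config: HFModelConfig = omega_conf_to_dataclass(omega_model_config, dataclass_type=HFModelConfig)

    # One ["dp", "infer_tp"] device mesh is shared by the rollout and the dispatch logic.
    infer_tp = worker.config.rollout.tensor_model_parallel_size
    assert worker.world_size % infer_tp == 0
    rollout_device_mesh = init_device_mesh(
        get_device_name(),
        mesh_shape=(worker.world_size // infer_tp, infer_tp),
        mesh_dim_names=["dp", "infer_tp"],
    )

    # A single RolloutWorker replaces the backend-specific vLLMRollout / SGLangRollout
    # construction; the backend is selected by rollout_config.
    rollout_worker = RolloutWorker(config=rollout_config, model_config=model_config)
    return rollout_worker, rollout_device_mesh
```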

### Design & Code Changes

> Demonstrate the high-level design if this PR is complex, and list the specific changes.

- `ActorRolloutRefWorker._build_rollout` (Megatron path) now converts the rollout and model configs to the `RolloutConfig` / `HFModelConfig` dataclasses, builds a single `["dp", "infer_tp"]` device mesh, and instantiates one `RolloutWorker` that serves both the vLLM and SGLang backends; it returns `(rollout_worker, sharding_manager)` instead of `(rollout, sharding_manager)`.
- The sharding managers take their inference engine from the worker: `rollout_worker.rollout.inference_engine` for vLLM and `rollout_worker.rollout._engine` for SGLang (see the sketch after this list).
- `MegatronSGLangShardingManager` indexes its device mesh with `"infer_tp"` instead of `"tp"`, matching the mesh built by the worker.
- `RayPPOTrainer` only enforces the `train_batch_size * rollout.n` divisibility check when `actor.use_dynamic_bsz` is disabled.
- `AsyncActorRolloutRefWorker.wake_up` / `sleep` now always forward to the rollout instead of gating on `free_cache_engine`.
- The Megatron GRPO example script turns on param/optimizer/grad offload and lowers `ppo_max_token_len_per_gpu` from 24000 to 12000.
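
A minimal sketch of how the new `RolloutWorker` plugs into the existing sharding managers. The remaining constructor arguments (model/transformer/rollout configs, layer-name mappings, device mesh, weight converter) are gathered here in a hypothetical `common_kwargs` dict for brevity; only the `inference_engine` wiring shown below is taken from the diff.

```python
# Hedged sketch: the sharding managers now receive the inference engine from the
# RolloutWorker rather than from a locally built vLLMRollout / SGLangRollout.
# `common_kwargs` is a hypothetical grouping of the unchanged constructor arguments.
from verl.workers.sharding_manager.megatron_sglang import MegatronSGLangShardingManager
from verl.workers.sharding_manager.megatron_vllm import MegatronVLLMShardingManager


def build_sharding_manager(rollout_name: str, rollout_worker, common_kwargs: dict):
    if rollout_name == "vllm":
        return MegatronVLLMShardingManager(
            inference_engine=rollout_worker.rollout.inference_engine,  # was: rollout.inference_engine
            **common_kwargs,
        )
    if rollout_name == "sglang":
        # MegatronSGLangShardingManager now also indexes its device mesh with
        # "infer_tp" (previously "tp"), matching the ["dp", "infer_tp"] mesh above.
        return MegatronSGLangShardingManager(
            inference_engine=rollout_worker.rollout._engine,  # was: rollout._engine
            **common_kwargs,
        )
    raise NotImplementedError(f"unsupported rollout backend: {rollout_name}")
```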

### Checklist Before Submitting

> [!IMPORTANT]
> Please check all the following items before requesting a review,
otherwise the reviewer might deprioritize this PR for review.

- [ ] Read the [Contribute
Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md).
- [ ] Apply [pre-commit
checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting):
`pre-commit install && pre-commit run --all-files --show-diff-on-failure
--color=always`
- [ ] Add / Update [the
documentation](https://github.com/volcengine/verl/tree/main/docs).
- [ ] Add unit or end-to-end test(s) to [the CI
workflow](https://github.com/volcengine/verl/tree/main/.github/workflows)
to cover all the code. If not feasible, explain why: ...
- [ ] Once your PR is ready for CI, send a message in [the `ci-request`
channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the
`verl` Slack
workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ).
(If not accessible, please try [the Feishu group
(飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).)
Commit 8494135e5c (parent 43cb93c8d1)
Author: Chi Zhang, 2025-08-19 15:07:52 +08:00, committed by GitHub
4 changed files with 87 additions and 119 deletions

--- Changed file 1 of 4: Megatron GRPO example run script ---

@@ -10,6 +10,8 @@ math_test_path=$HOME/data/math/test.parquet
train_files="['$gsm8k_train_path', '$math_train_path']"
test_files="['$gsm8k_test_path', '$math_test_path']"
offload=True
python3 -m verl.trainer.main_ppo --config-path=config \
--config-name='ppo_megatron_trainer.yaml'\
algorithm.adv_estimator=grpo \
@@ -24,15 +26,19 @@ python3 -m verl.trainer.main_ppo --config-path=config \
actor_rollout_ref.actor.optim.lr=1e-6 \
actor_rollout_ref.actor.ppo_mini_batch_size=256 \
actor_rollout_ref.actor.use_dynamic_bsz=True \
actor_rollout_ref.actor.ppo_max_token_len_per_gpu=24000 \
actor_rollout_ref.actor.ppo_max_token_len_per_gpu=12000 \
actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=2 \
actor_rollout_ref.actor.megatron.tensor_model_parallel_size=2 \
actor_rollout_ref.actor.use_kl_loss=True \
actor_rollout_ref.actor.kl_loss_coef=0.001 \
actor_rollout_ref.actor.kl_loss_type=low_var_kl \
actor_rollout_ref.actor.entropy_coeff=0 \
actor_rollout_ref.actor.megatron.param_offload=${offload} \
actor_rollout_ref.actor.megatron.optimizer_offload=${offload} \
actor_rollout_ref.actor.megatron.grad_offload=${offload} \
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4 \
actor_rollout_ref.rollout.tensor_model_parallel_size=2 \
actor_rollout_ref.ref.megatron.param_offload=${offload} \
actor_rollout_ref.rollout.name=vllm \
actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \
actor_rollout_ref.rollout.n=5 \

--- Changed file 2 of 4: RayPPOTrainer (PPO Ray trainer) ---

@@ -388,30 +388,32 @@ class RayPPOTrainer:
config = self.config
# number of GPUs total
n_gpus = config.trainer.n_gpus_per_node * config.trainer.nnodes
if config.actor_rollout_ref.actor.strategy == "megatron":
model_parallel_size = (
config.actor_rollout_ref.actor.megatron.tensor_model_parallel_size
* config.actor_rollout_ref.actor.megatron.pipeline_model_parallel_size
)
assert (
n_gpus % (model_parallel_size * config.actor_rollout_ref.actor.megatron.context_parallel_size) == 0
), (
f"n_gpus ({n_gpus}) must be divisible by model_parallel_size ({model_parallel_size}) times "
f"context_parallel_size ({config.actor_rollout_ref.actor.megatron.context_parallel_size})"
)
megatron_dp = n_gpus // (
model_parallel_size * config.actor_rollout_ref.actor.megatron.context_parallel_size
)
minimal_bsz = megatron_dp * config.actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu
else:
minimal_bsz = n_gpus
# 1. Check total batch size for data correctness
real_train_batch_size = config.data.train_batch_size * config.actor_rollout_ref.rollout.n
assert real_train_batch_size % minimal_bsz == 0, (
f"real_train_batch_size ({real_train_batch_size}) must be divisible by minimal possible batch size "
f"({minimal_bsz})"
)
if not config.actor_rollout_ref.actor.use_dynamic_bsz:
if config.actor_rollout_ref.actor.strategy == "megatron":
model_parallel_size = (
config.actor_rollout_ref.actor.megatron.tensor_model_parallel_size
* config.actor_rollout_ref.actor.megatron.pipeline_model_parallel_size
)
assert (
n_gpus % (model_parallel_size * config.actor_rollout_ref.actor.megatron.context_parallel_size) == 0
), (
f"n_gpus ({n_gpus}) must be divisible by model_parallel_size ({model_parallel_size}) times "
f"context_parallel_size ({config.actor_rollout_ref.actor.megatron.context_parallel_size})"
)
megatron_dp = n_gpus // (
model_parallel_size * config.actor_rollout_ref.actor.megatron.context_parallel_size
)
minimal_bsz = megatron_dp * config.actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu
else:
minimal_bsz = n_gpus
# 1. Check total batch size for data correctness
real_train_batch_size = config.data.train_batch_size * config.actor_rollout_ref.rollout.n
assert real_train_batch_size % minimal_bsz == 0, (
f"real_train_batch_size ({real_train_batch_size}) must be divisible by minimal possible batch size "
f"({minimal_bsz})"
)
# A helper function to check "micro_batch_size" vs "micro_batch_size_per_gpu"
# We throw an error if the user sets both. The new convention is "..._micro_batch_size_per_gpu".

--- Changed file 3 of 4: Megatron ActorRolloutRefWorker / AsyncActorRolloutRefWorker ---

@@ -15,6 +15,7 @@
The main entry point to run the PPO algorithm
"""
import copy
import datetime
import logging
import os
@@ -25,7 +26,7 @@ import psutil
import torch
import torch.distributed
from codetiming import Timer
from omegaconf import DictConfig, OmegaConf
from omegaconf import DictConfig, OmegaConf, open_dict
try:
from mindspeed.megatron_adaptor import repatch
@@ -61,9 +62,10 @@ from verl.utils.profiler import (
)
from verl.utils.profiler.performance import reduce_timing, topk_reduce_ratio_min_max
from verl.workers.actor.megatron_actor import MegatronPPOActor
from verl.workers.config import McoreCriticConfig, RolloutConfig
from verl.workers.config import HFModelConfig, McoreCriticConfig, RolloutConfig
from verl.workers.critic.megatron_critic import MegatronPPOCritic
from verl.workers.reward_model.megatron.reward_model import MegatronRewardModel
from verl.workers.rollout.rollout_worker import RolloutWorker
logger = logging.getLogger(__file__)
logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))
@@ -381,47 +383,49 @@ class ActorRolloutRefWorker(MegatronWorker, DistProfilerExtension):
"gate_proj_layer_name": "linear_fc1.",
}
rollout_name = self.config.rollout.name
rollout_config: RolloutConfig = omega_conf_to_dataclass(self.config.rollout)
if self.config.rollout.name == "vllm":
from torch.distributed.device_mesh import init_device_mesh
# (vermouth1992). self.config.model in megatron differs from that of fsdp in the override_config.
# To workaround this we deepcopy self.config.model and make them compatible
omega_model_config = copy.deepcopy(self.config.model)
with open_dict(omega_model_config):
override_config = omega_model_config.override_config.pop("model_config")
omega_model_config.override_config = override_config
model_config: HFModelConfig = omega_conf_to_dataclass(omega_model_config, dataclass_type=HFModelConfig)
infer_tp = self.config.rollout.tensor_model_parallel_size
dp = self.world_size // infer_tp
assert self.world_size % infer_tp == 0, (
f"rollout world_size: {self.world_size} is not divisible by infer_tp: {infer_tp}"
)
rollout_device_mesh = init_device_mesh(
get_device_name(), mesh_shape=(dp, infer_tp), mesh_dim_names=["dp", "infer_tp"]
)
# build rollout worker inside hybrid engine
log_gpu_memory_usage(f"Before building {rollout_name} rollout", logger=logger)
rollout_worker = RolloutWorker(config=rollout_config, model_config=model_config)
log_gpu_memory_usage(f"After building {rollout_name} rollout", logger=logger)
is_collect = rollout_device_mesh["infer_tp"].get_local_rank() == 0
self._register_dispatch_collect_info(
"rollout", dp_rank=rollout_device_mesh["dp"].get_local_rank(), is_collect=is_collect
)
from verl.models.mcore import get_mcore_weight_converter
weight_converter = get_mcore_weight_converter(self.actor_model_config, self.dtype)
if self.config.rollout.name == "vllm":
# perform weight resharding between actor and rollout
from verl.workers.rollout.vllm_rollout import vLLMRollout
from verl.workers.sharding_manager.megatron_vllm import MegatronVLLMShardingManager
# NOTE(sgm): If the QKV and gate_up projection layer are concate together in actor,
# we will reorganize their weight format when resharding from actor to rollout.
infer_tp = self.config.rollout.tensor_model_parallel_size
dp = self.world_size // infer_tp
assert self.world_size % infer_tp == 0, (
f"rollout world_size: {self.world_size} is not divisible by infer_tp: {infer_tp}"
)
rollout_device_mesh = init_device_mesh(
get_device_name(), mesh_shape=(dp, infer_tp), mesh_dim_names=["dp", "infer_tp"]
)
log_gpu_memory_usage("Before building vllm rollout", logger=None)
local_path = copy_to_local(self.config.model.path, use_shm=self.config.model.get("use_shm", False))
from verl.workers.rollout.vllm_rollout import vLLMAsyncRollout
vllm_rollout_cls = vLLMRollout if self.config.rollout.mode == "sync" else vLLMAsyncRollout
rollout = vllm_rollout_cls(
model_path=local_path,
config=rollout_config,
tokenizer=self.tokenizer,
model_hf_config=self.actor_model_config,
device_mesh=rollout_device_mesh,
trust_remote_code=trust_remote_code,
)
log_gpu_memory_usage("After building vllm rollout", logger=logger)
# perform weight resharding between actor and rollout
from verl.models.mcore import get_mcore_weight_converter
weight_converter = get_mcore_weight_converter(self.actor_model_config, self.dtype)
sharding_manager = MegatronVLLMShardingManager(
inference_engine=rollout.inference_engine,
inference_engine=rollout_worker.rollout.inference_engine,
model_config=self.actor_model_config,
transformer_config=self.tf_config,
rollout_config=self.config.rollout,
@@ -434,14 +438,7 @@ class ActorRolloutRefWorker(MegatronWorker, DistProfilerExtension):
)
log_gpu_memory_usage("After building sharding manager", logger=logger)
is_collect = rollout_device_mesh["infer_tp"].get_local_rank() == 0
self._register_dispatch_collect_info(
"rollout", dp_rank=rollout_device_mesh["dp"].get_local_rank(), is_collect=is_collect
)
elif self.config.rollout.name == "sglang":
from verl.workers.rollout.sglang_rollout.sglang_rollout import SGLangRollout
# NOTE(linjunrong): Due to recent fp8 support in SGLang. Now importing any symbol relate to SGLang's
# model_runner would check CUDA device capability.
# However, due to verl's setting, the main process of ray can not find any CUDA device, which would
@@ -451,38 +448,10 @@ class ActorRolloutRefWorker(MegatronWorker, DistProfilerExtension):
# check: https://github.com/sgl-project/sglang/blob/00f42707eaddfc2c0528e5b1e0094025c640b7a0/python/sglang/srt/layers/quantization/fp8_utils.py#L76
from verl.workers.sharding_manager.megatron_sglang import MegatronSGLangShardingManager
infer_tp = self.config.rollout.tensor_model_parallel_size
dp = self.world_size // infer_tp
assert self.world_size % infer_tp == 0, (
f"rollout world_size: {self.world_size} is not divisible by infer_tp: {infer_tp}"
)
rollout_device_mesh = init_device_mesh(
"cpu", mesh_shape=(dp, infer_tp, 1), mesh_dim_names=("dp", "tp", "pp")
)
is_collect = rollout_device_mesh["tp"].get_local_rank() == 0
self._register_dispatch_collect_info(
"rollout", dp_rank=rollout_device_mesh["dp"].get_local_rank(), is_collect=is_collect
)
local_path = copy_to_local(self.config.model.path)
log_gpu_memory_usage(f"Before building {self.config.rollout.name} rollout", logger=None)
rollout = SGLangRollout(
actor_module=local_path,
config=rollout_config,
processing_class=self.processor if self.processor is not None else self.tokenizer,
model_hf_config=self.actor_model_config,
trust_remote_code=trust_remote_code,
device_mesh=rollout_device_mesh,
)
log_gpu_memory_usage(f"After building {self.config.rollout.name} rollout", logger=None)
from verl.models.mcore import get_mcore_weight_converter
weight_converter = get_mcore_weight_converter(self.actor_model_config, self.dtype)
sharding_manager = MegatronSGLangShardingManager(
actor_module=self.actor.actor_module,
inference_engine=rollout._engine,
inference_engine=rollout_worker.rollout._engine,
model_config=self.actor_model_config,
rollout_config=self.config.rollout,
transformer_config=self.tf_config,
@@ -496,7 +465,7 @@ class ActorRolloutRefWorker(MegatronWorker, DistProfilerExtension):
else:
raise NotImplementedError("Only vllmRollout is supported with Megatron now")
print(f"rollout and sharding manager init done sharding_manager: {sharding_manager}")
return rollout, sharding_manager
return rollout_worker, sharding_manager
@register(dispatch_mode=Dispatch.ONE_TO_ALL)
def init_model(self):
@@ -778,7 +747,7 @@ class ActorRolloutRefWorker(MegatronWorker, DistProfilerExtension):
class AsyncActorRolloutRefWorker(ActorRolloutRefWorker):
def _build_rollout(self, trust_remote_code=False):
rollout, rollout_sharding_manager = super()._build_rollout(trust_remote_code)
rollout_worker, rollout_sharding_manager = super()._build_rollout(trust_remote_code)
# NOTE: rollout is not actually initialized here, it's deferred
# to be initialized by AsyncvLLMServer.
@@ -788,20 +757,14 @@ class AsyncActorRolloutRefWorker(ActorRolloutRefWorker):
self.vllm_tp_rank = int(os.environ["RANK"]) % self.vllm_tp_size
# used for sleep/wake_up
rollout.sharding_manager = rollout_sharding_manager
rollout_worker.rollout.sharding_manager = rollout_sharding_manager
return rollout, rollout_sharding_manager
return rollout_worker, rollout_sharding_manager
# ============================ vLLM related ============================
@register(dispatch_mode=Dispatch.DIRECT_ROLLOUT_METHOD)
def execute_method(self, method: str | bytes, *args, **kwargs):
"""Called by ExternalRayDistributedExecutor collective_rpc."""
if self.vllm_tp_rank == 0 and method != "execute_model":
print(
f"[DP={self.vllm_dp_rank},TP={self.vllm_tp_rank}] execute_method: "
f"{method if isinstance(method, str) else 'Callable'}"
)
return self.rollout.execute_method(method, *args, **kwargs)
@register(dispatch_mode=Dispatch.DIRECT_ROLLOUT_METHOD)
@@ -828,15 +791,12 @@ class AsyncActorRolloutRefWorker(ActorRolloutRefWorker):
@register(dispatch_mode=Dispatch.DIRECT_ROLLOUT_METHOD)
async def wake_up(self):
if self.config.rollout.free_cache_engine:
await self.rollout.wake_up()
# return something to block the caller
await self.rollout.wake_up()
return True
@register(dispatch_mode=Dispatch.DIRECT_ROLLOUT_METHOD)
async def sleep(self):
if self.config.rollout.free_cache_engine:
await self.rollout.sleep()
await self.rollout.sleep()
# return something to block the caller
return True

--- Changed file 4 of 4: MegatronSGLangShardingManager ---

@@ -101,7 +101,7 @@ class MegatronSGLangShardingManager(BaseShardingManager):
self.offload_param = offload_param
if self.device_mesh is not None:
self.infer_tp_size = self.device_mesh["tp"].mesh.size()[0]
self.infer_tp_size = self.device_mesh["infer_tp"].mesh.size()[0]
else:
self.infer_tp_size = self.inference_engine._tp_size
@@ -141,7 +141,7 @@ class MegatronSGLangShardingManager(BaseShardingManager):
- Main logic: https://github.com/THUDM/slime/blob/fb7605cc5fb09af0f9369d37f7192f12bddee577/slime/ray/ppo_actor.py#L452
- runtime envs: https://github.com/THUDM/slime/blob/fb7605cc5fb09af0f9369d37f7192f12bddee577/slime/ray/ppo_actor.py#L39
"""
if self.device_mesh["tp"].get_local_rank() == 0 and self.rollout_config.free_cache_engine:
if self.device_mesh["infer_tp"].get_local_rank() == 0 and self.rollout_config.free_cache_engine:
await self.inference_engine.resume_memory_occupation()
named_tensors = params
@@ -150,15 +150,15 @@
await sgl_update_weights(
engine=self.inference_engine,
params_batch=params_batch,
device_mesh_key="tp",
device_mesh_key="infer_tp",
device_mesh=self.device_mesh,
)
if self.device_mesh["tp"].get_local_rank() == 0:
if self.device_mesh["infer_tp"].get_local_rank() == 0:
await self.inference_engine.flush_cache()
async def release_memory(self):
if self.device_mesh["tp"].get_local_rank() == 0 and self.rollout_config.free_cache_engine:
if self.device_mesh["infer_tp"].get_local_rank() == 0 and self.rollout_config.free_cache_engine:
await self.inference_engine.release_memory_occupation()
@GPUMemoryLogger(role="MegatronSGLangShardingManager enter", logger=logger)
@@ -206,7 +206,7 @@ class MegatronSGLangShardingManager(BaseShardingManager):
# DP_COMPUTE_PROTO: all training ranks are dp, the same as fsdp
if self.infer_tp_size == 1:
return data
all_gather_data_proto(data, self.device_mesh["tp"].get_group())
all_gather_data_proto(data, self.device_mesh["infer_tp"].get_group())
return data
@GPUMemoryLogger(role="megatron sglang sharding_manager", logger=logger)
@@ -214,4 +214,4 @@
# DP_COMPUTE_PROTO: all training ranks are dp, the same as fsdp
if self.infer_tp_size == 1:
return data
return data.chunk(chunks=self.infer_tp_size)[self.device_mesh["tp"].get_local_rank()]
return data.chunk(chunks=self.infer_tp_size)[self.device_mesh["infer_tp"].get_local_rank()]