[rollout] feat: use rollout worker in MegatronWorker (#3111)

### What does this PR do?

- Build the rollout in the Megatron `ActorRolloutRefWorker` through the unified `RolloutWorker` (covering both vLLM and SGLang) instead of constructing backend-specific rollouts directly; `_build_rollout` now returns the `RolloutWorker`, and the sharding managers take their inference engine from `rollout_worker.rollout`.
- Rename the SGLang sharding manager's device-mesh dimension from `tp` to `infer_tp`, skip the trainer's batch-size divisibility check when dynamic batch sizing is enabled, and enable Megatron param/grad/optimizer offload in the GRPO example script.

### Checklist Before Starting

- [ ] Search for similar PRs. Paste at least one query link here: ...
- [ ] Format the PR title as `[{modules}] {type}: {description}` (This will be checked by the CI)
  - `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data`
  - If this PR involves multiple modules, separate them with `,` like `[megatron, fsdp, doc]`
  - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test`
  - If this PR breaks any API (CLI arguments, config, function signature, etc.), add `[BREAKING]` to the beginning of the title.
    - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching`

### Test

> For changes that can not be tested by CI (e.g., algorithm
implementation, new model support), validate by experiment(s) and show
results like training curve plots, evaluation results, etc.

### API and Usage Example

> Demonstrate how the API changes if any, and provide usage example(s)
if possible.

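With this change, the Megatron `ActorRolloutRefWorker._build_rollout` returns a `RolloutWorker` instead of a backend-specific rollout object. A condensed sketch of the new construction path (taken from the diff below, not a standalone script; it runs inside the worker after Megatron and the actor model are initialized):

```python
# Condensed from _build_rollout in the diff below; not a standalone script.
rollout_config: RolloutConfig = omega_conf_to_dataclass(self.config.rollout)
model_config: HFModelConfig = omega_conf_to_dataclass(omega_model_config, dataclass_type=HFModelConfig)

# one RolloutWorker for every backend (vllm / sglang), built inside the hybrid engine
rollout_worker = RolloutWorker(config=rollout_config, model_config=model_config)

# the backend engine is now reached through rollout_worker.rollout, e.g. when
# wiring up the sharding manager:
#   vLLM:   rollout_worker.rollout.inference_engine
#   SGLang: rollout_worker.rollout._engine
return rollout_worker, sharding_manager
```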

### Design & Code Changes

> Demonstrate the high-level design if this PR is complex, and list the
specific changes.

### Checklist Before Submitting

> [!IMPORTANT]
> Please check all the following items before requesting a review,
otherwise the reviewer might deprioritize this PR for review.

- [ ] Read the [Contribute
Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md).
- [ ] Apply [pre-commit
checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting):
`pre-commit install && pre-commit run --all-files --show-diff-on-failure
--color=always`
- [ ] Add / Update [the
documentation](https://github.com/volcengine/verl/tree/main/docs).
- [ ] Add unit or end-to-end test(s) to [the CI
workflow](https://github.com/volcengine/verl/tree/main/.github/workflows)
to cover all the code. If not feasible, explain why: ...
- [ ] Once your PR is ready for CI, send a message in [the `ci-request`
channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the
`verl` Slack
workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ).
(If not accessible, please try [the Feishu group
(飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).)
Commit `8494135e5c` (parent `43cb93c8d1`), authored by Chi Zhang on 2025-08-19 15:07:52 +08:00 and committed by GitHub. 4 changed files with 87 additions and 119 deletions.

File 1 of 4 — GRPO Megatron example run script (file path not shown in this view):

```diff
@@ -10,6 +10,8 @@ math_test_path=$HOME/data/math/test.parquet
 train_files="['$gsm8k_train_path', '$math_train_path']"
 test_files="['$gsm8k_test_path', '$math_test_path']"
+offload=True
+
 python3 -m verl.trainer.main_ppo --config-path=config \
     --config-name='ppo_megatron_trainer.yaml'\
     algorithm.adv_estimator=grpo \
@@ -24,15 +26,19 @@ python3 -m verl.trainer.main_ppo --config-path=config \
     actor_rollout_ref.actor.optim.lr=1e-6 \
     actor_rollout_ref.actor.ppo_mini_batch_size=256 \
     actor_rollout_ref.actor.use_dynamic_bsz=True \
-    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=24000 \
+    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=12000 \
     actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=2 \
     actor_rollout_ref.actor.megatron.tensor_model_parallel_size=2 \
     actor_rollout_ref.actor.use_kl_loss=True \
     actor_rollout_ref.actor.kl_loss_coef=0.001 \
     actor_rollout_ref.actor.kl_loss_type=low_var_kl \
     actor_rollout_ref.actor.entropy_coeff=0 \
+    actor_rollout_ref.actor.megatron.param_offload=${offload} \
+    actor_rollout_ref.actor.megatron.optimizer_offload=${offload} \
+    actor_rollout_ref.actor.megatron.grad_offload=${offload} \
     actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4 \
     actor_rollout_ref.rollout.tensor_model_parallel_size=2 \
+    actor_rollout_ref.ref.megatron.param_offload=${offload} \
     actor_rollout_ref.rollout.name=vllm \
     actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \
     actor_rollout_ref.rollout.n=5 \
```
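The new `offload=True` switch feeds plain Hydra/OmegaConf dotted overrides. A minimal sketch (using `OmegaConf.from_dotlist` directly rather than the verl CLI) of how the added keys resolve into the nested config tree:

```python
# Minimal sketch (not from the PR): dotted CLI overrides like the ones added above
# resolve into a nested config tree. Uses plain OmegaConf; the real run goes
# through Hydra via verl.trainer.main_ppo.
from omegaconf import OmegaConf

overrides = [
    "actor_rollout_ref.actor.megatron.param_offload=True",
    "actor_rollout_ref.actor.megatron.optimizer_offload=True",
    "actor_rollout_ref.actor.megatron.grad_offload=True",
    "actor_rollout_ref.ref.megatron.param_offload=True",
]
cfg = OmegaConf.from_dotlist(overrides)

print(OmegaConf.to_yaml(cfg))
print(cfg.actor_rollout_ref.actor.megatron.param_offload)  # True
```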

File 2 of 4 — `RayPPOTrainer` batch-size validation (file path not shown):

```diff
@@ -388,30 +388,32 @@ class RayPPOTrainer:
         config = self.config
         # number of GPUs total
         n_gpus = config.trainer.n_gpus_per_node * config.trainer.nnodes
-        if config.actor_rollout_ref.actor.strategy == "megatron":
-            model_parallel_size = (
-                config.actor_rollout_ref.actor.megatron.tensor_model_parallel_size
-                * config.actor_rollout_ref.actor.megatron.pipeline_model_parallel_size
-            )
-            assert (
-                n_gpus % (model_parallel_size * config.actor_rollout_ref.actor.megatron.context_parallel_size) == 0
-            ), (
-                f"n_gpus ({n_gpus}) must be divisible by model_parallel_size ({model_parallel_size}) times "
-                f"context_parallel_size ({config.actor_rollout_ref.actor.megatron.context_parallel_size})"
-            )
-            megatron_dp = n_gpus // (
-                model_parallel_size * config.actor_rollout_ref.actor.megatron.context_parallel_size
-            )
-            minimal_bsz = megatron_dp * config.actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu
-        else:
-            minimal_bsz = n_gpus
-
-        # 1. Check total batch size for data correctness
-        real_train_batch_size = config.data.train_batch_size * config.actor_rollout_ref.rollout.n
-        assert real_train_batch_size % minimal_bsz == 0, (
-            f"real_train_batch_size ({real_train_batch_size}) must be divisible by minimal possible batch size "
-            f"({minimal_bsz})"
-        )
+        if not config.actor_rollout_ref.actor.use_dynamic_bsz:
+            if config.actor_rollout_ref.actor.strategy == "megatron":
+                model_parallel_size = (
+                    config.actor_rollout_ref.actor.megatron.tensor_model_parallel_size
+                    * config.actor_rollout_ref.actor.megatron.pipeline_model_parallel_size
+                )
+                assert (
+                    n_gpus % (model_parallel_size * config.actor_rollout_ref.actor.megatron.context_parallel_size) == 0
+                ), (
+                    f"n_gpus ({n_gpus}) must be divisible by model_parallel_size ({model_parallel_size}) times "
+                    f"context_parallel_size ({config.actor_rollout_ref.actor.megatron.context_parallel_size})"
+                )
+                megatron_dp = n_gpus // (
+                    model_parallel_size * config.actor_rollout_ref.actor.megatron.context_parallel_size
+                )
+                minimal_bsz = megatron_dp * config.actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu
+            else:
+                minimal_bsz = n_gpus
+
+            # 1. Check total batch size for data correctness
+            real_train_batch_size = config.data.train_batch_size * config.actor_rollout_ref.rollout.n
+            assert real_train_batch_size % minimal_bsz == 0, (
+                f"real_train_batch_size ({real_train_batch_size}) must be divisible by minimal possible batch size "
+                f"({minimal_bsz})"
+            )

         # A helper function to check "micro_batch_size" vs "micro_batch_size_per_gpu"
         # We throw an error if the user sets both. The new convention is "..._micro_batch_size_per_gpu".
```
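For intuition, the check that is now skipped under dynamic batch sizing boils down to the following arithmetic; the numbers below are hypothetical, not taken from the PR:

```python
# Minimal sketch (hypothetical numbers): what the guarded Megatron check computes.
# With use_dynamic_bsz=True the whole check is now skipped.
n_gpus = 16                       # trainer.n_gpus_per_node * trainer.nnodes
tp, pp, cp = 2, 2, 1              # tensor / pipeline / context model parallel sizes
ppo_micro_batch_size_per_gpu = 4
train_batch_size, rollout_n = 256, 5

model_parallel_size = tp * pp
assert n_gpus % (model_parallel_size * cp) == 0
megatron_dp = n_gpus // (model_parallel_size * cp)        # 4 data-parallel replicas
minimal_bsz = megatron_dp * ppo_micro_batch_size_per_gpu  # 16
real_train_batch_size = train_batch_size * rollout_n      # 1280
assert real_train_batch_size % minimal_bsz == 0           # 1280 % 16 == 0
```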

File 3 of 4 — Megatron `ActorRolloutRefWorker` / `AsyncActorRolloutRefWorker` (file path not shown):

```diff
@@ -15,6 +15,7 @@
 The main entry point to run the PPO algorithm
 """

+import copy
 import datetime
 import logging
 import os
@@ -25,7 +26,7 @@ import psutil
 import torch
 import torch.distributed
 from codetiming import Timer
-from omegaconf import DictConfig, OmegaConf
+from omegaconf import DictConfig, OmegaConf, open_dict

 try:
     from mindspeed.megatron_adaptor import repatch
@@ -61,9 +62,10 @@ from verl.utils.profiler import (
 )
 from verl.utils.profiler.performance import reduce_timing, topk_reduce_ratio_min_max
 from verl.workers.actor.megatron_actor import MegatronPPOActor
-from verl.workers.config import McoreCriticConfig, RolloutConfig
+from verl.workers.config import HFModelConfig, McoreCriticConfig, RolloutConfig
 from verl.workers.critic.megatron_critic import MegatronPPOCritic
 from verl.workers.reward_model.megatron.reward_model import MegatronRewardModel
+from verl.workers.rollout.rollout_worker import RolloutWorker

 logger = logging.getLogger(__file__)
 logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))
```
```diff
@@ -381,47 +383,49 @@ class ActorRolloutRefWorker(MegatronWorker, DistProfilerExtension):
             "gate_proj_layer_name": "linear_fc1.",
         }

+        rollout_name = self.config.rollout.name
         rollout_config: RolloutConfig = omega_conf_to_dataclass(self.config.rollout)
-        if self.config.rollout.name == "vllm":
-            from torch.distributed.device_mesh import init_device_mesh
-
-            from verl.workers.rollout.vllm_rollout import vLLMRollout
+
+        # (vermouth1992). self.config.model in megatron differs from that of fsdp in the override_config.
+        # To workaround this we deepcopy self.config.model and make them compatible
+        omega_model_config = copy.deepcopy(self.config.model)
+        with open_dict(omega_model_config):
+            override_config = omega_model_config.override_config.pop("model_config")
+            omega_model_config.override_config = override_config
+
+        model_config: HFModelConfig = omega_conf_to_dataclass(omega_model_config, dataclass_type=HFModelConfig)
+
+        infer_tp = self.config.rollout.tensor_model_parallel_size
+        dp = self.world_size // infer_tp
+        assert self.world_size % infer_tp == 0, (
+            f"rollout world_size: {self.world_size} is not divisible by infer_tp: {infer_tp}"
+        )
+        rollout_device_mesh = init_device_mesh(
+            get_device_name(), mesh_shape=(dp, infer_tp), mesh_dim_names=["dp", "infer_tp"]
+        )
+
+        # build rollout worker inside hybrid engine
+        log_gpu_memory_usage(f"Before building {rollout_name} rollout", logger=logger)
+        rollout_worker = RolloutWorker(config=rollout_config, model_config=model_config)
+        log_gpu_memory_usage(f"After building {rollout_name} rollout", logger=logger)
+
+        is_collect = rollout_device_mesh["infer_tp"].get_local_rank() == 0
+        self._register_dispatch_collect_info(
+            "rollout", dp_rank=rollout_device_mesh["dp"].get_local_rank(), is_collect=is_collect
+        )
+
+        from verl.models.mcore import get_mcore_weight_converter
+
+        weight_converter = get_mcore_weight_converter(self.actor_model_config, self.dtype)
+
+        if self.config.rollout.name == "vllm":
+            # perform weight resharding between actor and rollout
             from verl.workers.sharding_manager.megatron_vllm import MegatronVLLMShardingManager

-            # NOTE(sgm): If the QKV and gate_up projection layer are concate together in actor,
-            # we will reorganize their weight format when resharding from actor to rollout.
-            infer_tp = self.config.rollout.tensor_model_parallel_size
-            dp = self.world_size // infer_tp
-            assert self.world_size % infer_tp == 0, (
-                f"rollout world_size: {self.world_size} is not divisible by infer_tp: {infer_tp}"
-            )
-            rollout_device_mesh = init_device_mesh(
-                get_device_name(), mesh_shape=(dp, infer_tp), mesh_dim_names=["dp", "infer_tp"]
-            )
-
-            log_gpu_memory_usage("Before building vllm rollout", logger=None)
-            local_path = copy_to_local(self.config.model.path, use_shm=self.config.model.get("use_shm", False))
-            from verl.workers.rollout.vllm_rollout import vLLMAsyncRollout
-
-            vllm_rollout_cls = vLLMRollout if self.config.rollout.mode == "sync" else vLLMAsyncRollout
-            rollout = vllm_rollout_cls(
-                model_path=local_path,
-                config=rollout_config,
-                tokenizer=self.tokenizer,
-                model_hf_config=self.actor_model_config,
-                device_mesh=rollout_device_mesh,
-                trust_remote_code=trust_remote_code,
-            )
-            log_gpu_memory_usage("After building vllm rollout", logger=logger)
-
-            # perform weight resharding between actor and rollout
-            from verl.models.mcore import get_mcore_weight_converter
-
-            weight_converter = get_mcore_weight_converter(self.actor_model_config, self.dtype)
             sharding_manager = MegatronVLLMShardingManager(
-                inference_engine=rollout.inference_engine,
+                inference_engine=rollout_worker.rollout.inference_engine,
                 model_config=self.actor_model_config,
                 transformer_config=self.tf_config,
                 rollout_config=self.config.rollout,
```
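The deepcopy/`open_dict` dance above exists because the Megatron model config nests the HF overrides under `override_config.model_config`, so the block flattens them one level before `omega_conf_to_dataclass(..., dataclass_type=HFModelConfig)` runs (per the inline comment). A self-contained sketch of that transformation on a plain OmegaConf object (illustrative keys, not the real verl schema):

```python
# Minimal sketch (illustrative keys, not the real verl schema): flatten
# override_config.model_config up one level before converting to HFModelConfig.
import copy

from omegaconf import OmegaConf, open_dict

model_cfg = OmegaConf.create(
    {"path": "Qwen/Qwen2-7B", "override_config": {"model_config": {"attention_dropout": 0.0}}}
)

omega_model_config = copy.deepcopy(model_cfg)
with open_dict(omega_model_config):  # allow structural changes on a (possibly) struct config
    override_config = omega_model_config.override_config.pop("model_config")
    omega_model_config.override_config = override_config

print(OmegaConf.to_yaml(omega_model_config))
# override_config now holds {"attention_dropout": 0.0} directly, matching the
# shape the FSDP path already passes to omega_conf_to_dataclass.
```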
```diff
@@ -434,14 +438,7 @@ class ActorRolloutRefWorker(MegatronWorker, DistProfilerExtension):
             )
             log_gpu_memory_usage("After building sharding manager", logger=logger)

-            is_collect = rollout_device_mesh["infer_tp"].get_local_rank() == 0
-            self._register_dispatch_collect_info(
-                "rollout", dp_rank=rollout_device_mesh["dp"].get_local_rank(), is_collect=is_collect
-            )
-
         elif self.config.rollout.name == "sglang":
-            from verl.workers.rollout.sglang_rollout.sglang_rollout import SGLangRollout
-
             # NOTE(linjunrong): Due to recent fp8 support in SGLang. Now importing any symbol relate to SGLang's
             # model_runner would check CUDA device capability.
             # However, due to verl's setting, the main process of ray can not find any CUDA device, which would
@@ -451,38 +448,10 @@ class ActorRolloutRefWorker(MegatronWorker, DistProfilerExtension):
             # check: https://github.com/sgl-project/sglang/blob/00f42707eaddfc2c0528e5b1e0094025c640b7a0/python/sglang/srt/layers/quantization/fp8_utils.py#L76
             from verl.workers.sharding_manager.megatron_sglang import MegatronSGLangShardingManager

-            infer_tp = self.config.rollout.tensor_model_parallel_size
-            dp = self.world_size // infer_tp
-            assert self.world_size % infer_tp == 0, (
-                f"rollout world_size: {self.world_size} is not divisible by infer_tp: {infer_tp}"
-            )
-            rollout_device_mesh = init_device_mesh(
-                "cpu", mesh_shape=(dp, infer_tp, 1), mesh_dim_names=("dp", "tp", "pp")
-            )
-            is_collect = rollout_device_mesh["tp"].get_local_rank() == 0
-            self._register_dispatch_collect_info(
-                "rollout", dp_rank=rollout_device_mesh["dp"].get_local_rank(), is_collect=is_collect
-            )
-
-            local_path = copy_to_local(self.config.model.path)
-            log_gpu_memory_usage(f"Before building {self.config.rollout.name} rollout", logger=None)
-            rollout = SGLangRollout(
-                actor_module=local_path,
-                config=rollout_config,
-                processing_class=self.processor if self.processor is not None else self.tokenizer,
-                model_hf_config=self.actor_model_config,
-                trust_remote_code=trust_remote_code,
-                device_mesh=rollout_device_mesh,
-            )
-            log_gpu_memory_usage(f"After building {self.config.rollout.name} rollout", logger=None)
-
-            from verl.models.mcore import get_mcore_weight_converter
-
             weight_converter = get_mcore_weight_converter(self.actor_model_config, self.dtype)
             sharding_manager = MegatronSGLangShardingManager(
                 actor_module=self.actor.actor_module,
-                inference_engine=rollout._engine,
+                inference_engine=rollout_worker.rollout._engine,
                 model_config=self.actor_model_config,
                 rollout_config=self.config.rollout,
                 transformer_config=self.tf_config,
@@ -496,7 +465,7 @@ class ActorRolloutRefWorker(MegatronWorker, DistProfilerExtension):
         else:
             raise NotImplementedError("Only vllmRollout is supported with Megatron now")
         print(f"rollout and sharding manager init done sharding_manager: {sharding_manager}")
-        return rollout, sharding_manager
+        return rollout_worker, sharding_manager

     @register(dispatch_mode=Dispatch.ONE_TO_ALL)
     def init_model(self):
```
```diff
@@ -778,7 +747,7 @@ class ActorRolloutRefWorker(MegatronWorker, DistProfilerExtension):
 class AsyncActorRolloutRefWorker(ActorRolloutRefWorker):
     def _build_rollout(self, trust_remote_code=False):
-        rollout, rollout_sharding_manager = super()._build_rollout(trust_remote_code)
+        rollout_worker, rollout_sharding_manager = super()._build_rollout(trust_remote_code)

         # NOTE: rollout is not actually initialized here, it's deferred
         # to be initialized by AsyncvLLMServer.
@@ -788,20 +757,14 @@ class AsyncActorRolloutRefWorker(ActorRolloutRefWorker):
         self.vllm_tp_rank = int(os.environ["RANK"]) % self.vllm_tp_size

         # used for sleep/wake_up
-        rollout.sharding_manager = rollout_sharding_manager
+        rollout_worker.rollout.sharding_manager = rollout_sharding_manager

-        return rollout, rollout_sharding_manager
+        return rollout_worker, rollout_sharding_manager

     # ============================ vLLM related ============================

     @register(dispatch_mode=Dispatch.DIRECT_ROLLOUT_METHOD)
     def execute_method(self, method: str | bytes, *args, **kwargs):
-        """Called by ExternalRayDistributedExecutor collective_rpc."""
-        if self.vllm_tp_rank == 0 and method != "execute_model":
-            print(
-                f"[DP={self.vllm_dp_rank},TP={self.vllm_tp_rank}] execute_method: "
-                f"{method if isinstance(method, str) else 'Callable'}"
-            )
         return self.rollout.execute_method(method, *args, **kwargs)

     @register(dispatch_mode=Dispatch.DIRECT_ROLLOUT_METHOD)
@@ -828,15 +791,12 @@ class AsyncActorRolloutRefWorker(ActorRolloutRefWorker):
     @register(dispatch_mode=Dispatch.DIRECT_ROLLOUT_METHOD)
     async def wake_up(self):
-        if self.config.rollout.free_cache_engine:
-            await self.rollout.wake_up()
-        # return something to block the caller
+        await self.rollout.wake_up()
         return True

     @register(dispatch_mode=Dispatch.DIRECT_ROLLOUT_METHOD)
     async def sleep(self):
-        if self.config.rollout.free_cache_engine:
-            await self.rollout.sleep()
+        await self.rollout.sleep()
         # return something to block the caller
         return True
```

File 4 of 4 — `MegatronSGLangShardingManager` (file path not shown):

```diff
@@ -101,7 +101,7 @@ class MegatronSGLangShardingManager(BaseShardingManager):
         self.offload_param = offload_param

         if self.device_mesh is not None:
-            self.infer_tp_size = self.device_mesh["tp"].mesh.size()[0]
+            self.infer_tp_size = self.device_mesh["infer_tp"].mesh.size()[0]
         else:
             self.infer_tp_size = self.inference_engine._tp_size
@@ -141,7 +141,7 @@ class MegatronSGLangShardingManager(BaseShardingManager):
         - Main logic: https://github.com/THUDM/slime/blob/fb7605cc5fb09af0f9369d37f7192f12bddee577/slime/ray/ppo_actor.py#L452
         - runtime envs: https://github.com/THUDM/slime/blob/fb7605cc5fb09af0f9369d37f7192f12bddee577/slime/ray/ppo_actor.py#L39
         """
-        if self.device_mesh["tp"].get_local_rank() == 0 and self.rollout_config.free_cache_engine:
+        if self.device_mesh["infer_tp"].get_local_rank() == 0 and self.rollout_config.free_cache_engine:
             await self.inference_engine.resume_memory_occupation()

         named_tensors = params
@@ -150,15 +150,15 @@ class MegatronSGLangShardingManager(BaseShardingManager):
         await sgl_update_weights(
             engine=self.inference_engine,
             params_batch=params_batch,
-            device_mesh_key="tp",
+            device_mesh_key="infer_tp",
             device_mesh=self.device_mesh,
         )

-        if self.device_mesh["tp"].get_local_rank() == 0:
+        if self.device_mesh["infer_tp"].get_local_rank() == 0:
             await self.inference_engine.flush_cache()

     async def release_memory(self):
-        if self.device_mesh["tp"].get_local_rank() == 0 and self.rollout_config.free_cache_engine:
+        if self.device_mesh["infer_tp"].get_local_rank() == 0 and self.rollout_config.free_cache_engine:
             await self.inference_engine.release_memory_occupation()

     @GPUMemoryLogger(role="MegatronSGLangShardingManager enter", logger=logger)
@@ -206,7 +206,7 @@ class MegatronSGLangShardingManager(BaseShardingManager):
         # DP_COMPUTE_PROTO: all training ranks are dp, the same as fsdp
         if self.infer_tp_size == 1:
             return data
-        all_gather_data_proto(data, self.device_mesh["tp"].get_group())
+        all_gather_data_proto(data, self.device_mesh["infer_tp"].get_group())
         return data

     @GPUMemoryLogger(role="megatron sglang sharding_manager", logger=logger)
@@ -214,4 +214,4 @@ class MegatronSGLangShardingManager(BaseShardingManager):
         # DP_COMPUTE_PROTO: all training ranks are dp, the same as fsdp
         if self.infer_tp_size == 1:
             return data
-        return data.chunk(chunks=self.infer_tp_size)[self.device_mesh["tp"].get_local_rank()]
+        return data.chunk(chunks=self.infer_tp_size)[self.device_mesh["infer_tp"].get_local_rank()]
```
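The last file consistently renames the SGLang sharding manager's device-mesh dimension from `tp` to `infer_tp`. A minimal sketch (not from the PR; single-process CPU run with the gloo backend and a recent PyTorch that supports named mesh slicing) of addressing a mesh dimension by its new name:

```python
# Minimal sketch (not from the PR): address the renamed "infer_tp" mesh dimension.
# Assumes CPU-only torch with the gloo backend, run as a single process.
import os

import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

# dp x infer_tp mesh; the inference TP dimension is now named "infer_tp" instead of "tp"
mesh = init_device_mesh("cpu", mesh_shape=(1, 1), mesh_dim_names=("dp", "infer_tp"))

print(mesh["infer_tp"].get_local_rank())  # rank of this process within its infer_tp group
print(mesh["infer_tp"].mesh.size()[0])    # infer_tp group size, as used for self.infer_tp_size

dist.destroy_process_group()
```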