mirror of
https://github.com/volcengine/verl.git
synced 2025-10-20 13:43:50 +08:00
[trainer] fix: only load memory in micro batch (#2908)
### What does this PR do? In update_actor, it load the whole bath into GPU memory, actually only the micro batch is necessary. It is a regression from https://github.com/volcengine/verl/pull/2477 ### Checklist Before Starting - [x] Search for similar PRs. Paste at least one query link here: https://github.com/volcengine/verl/pulls?q=is%3Apr+is%3Aopen+micro+batch - [x] Format the PR title as `[{modules}] {type}: {description}` (This will be checked by the CI) - `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data` - If this PR involves multiple modules, separate them with `,` like `[megatron, fsdp, doc]` - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test` - If this PR breaks any API (CLI arguments, config, function signature, etc.), add `[BREAKING]` to the beginning of the title. - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching` ### Test <img width="700" height="325" alt="截屏2025-08-05 下午1 01 53" src="https://github.com/user-attachments/assets/31dc4fea-8cb0-4f51-8ed2-f93d90a94040" /> <img width="1359" height="607" alt="截屏2025-08-05 下午12 45 50" src="https://github.com/user-attachments/assets/747636e6-b919-4eca-a3eb-5baf3722b5fc" /> ### API and Usage Example > Demonstrate how the API changes if any, and provide usage example(s) if possible. ```python # Add code snippet or script demonstrating how to use this ``` ### Design & Code Changes > Demonstrate the high-level design if this PR is complex, and list the specific changes. ### Checklist Before Submitting > [!IMPORTANT] > Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review. - [x] Read the [Contribute Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md). - [x] Apply [pre-commit checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting): `pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always` - [x] Add / Update [the documentation](https://github.com/volcengine/verl/tree/main/docs). - [x] Add unit or end-to-end test(s) to [the CI workflow](https://github.com/volcengine/verl/tree/main/.github/workflows) to cover all the code. If not feasible, explain why: ... - [x] Once your PR is ready for CI, send a message in [the `ci-request` channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the `verl` Slack workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ). (If not accessible, please try [the Feishu group (飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).) --------- Co-authored-by: Chi Zhang <zhangchi.usc1992@bytedance.com>
This commit is contained in:
@ -27,7 +27,7 @@ from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
|
||||
import verl.utils.torch_functional as verl_F
|
||||
from verl import DataProto
|
||||
from verl.trainer.ppo.core_algos import agg_loss, get_policy_loss_fn, kl_penalty
|
||||
from verl.utils.device import get_device_name, is_cuda_available, is_npu_available
|
||||
from verl.utils.device import get_device_id, get_device_name, is_cuda_available, is_npu_available
|
||||
from verl.utils.fsdp_utils import FSDPModule, fsdp2_clip_grad_norm_
|
||||
from verl.utils.profiler import GPUMemoryLogger
|
||||
from verl.utils.py_functional import append_to_dict
|
||||
@ -401,6 +401,7 @@ class DataParallelPPOActor(BasePPOActor):
|
||||
self.actor_optimizer.zero_grad()
|
||||
|
||||
for micro_batch in micro_batches:
|
||||
micro_batch = micro_batch.to(get_device_id())
|
||||
micro_batch_metrics = {}
|
||||
model_inputs = {**micro_batch.batch, **micro_batch.non_tensor_batch}
|
||||
response_mask = model_inputs["response_mask"]
|
||||
|
@ -699,9 +699,6 @@ class ActorRolloutRefWorker(Worker, DistProfilerExtension):
|
||||
@register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="actor"))
|
||||
@DistProfiler.annotate(color="red", role="actor_update")
|
||||
def update_actor(self, data: DataProto):
|
||||
# Support all hardwares
|
||||
data = data.to(get_device_id())
|
||||
|
||||
assert self._is_actor
|
||||
if self._is_offload_param:
|
||||
load_fsdp_model_to_gpu(self.actor_module_fsdp)
|
||||
@ -709,6 +706,8 @@ class ActorRolloutRefWorker(Worker, DistProfilerExtension):
|
||||
load_fsdp_optimizer(optimizer=self.actor_optimizer, device_id=get_device_id())
|
||||
|
||||
with self.ulysses_sharding_manager:
|
||||
data = data.to("cpu") # data will to device with each micro batch on actor.update_policy
|
||||
|
||||
# perform training
|
||||
with Timer(name="update_policy", logger=None) as timer:
|
||||
metrics = self.actor.update_policy(data=data)
|
||||
|
Reference in New Issue
Block a user