[worker] refactor: Add kwargs to checkpoint related functions in BaseEngine and its subclasses (#3662)

### What does this PR do?

Add `**kwargs` to the checkpoint APIs of `BaseEngine` (and thread them
through `FSDPEngine`/`MegatronEngine`) so that engines and pluggable
checkpoint backends can accept implementation-specific options without
changing the common interface. This makes the API easier to extend when
users subclass `BaseEngine` or integrate internal engines, while
preserving backward compatibility: existing call sites are unchanged,
and extra keyword arguments are simply ignored unless a subclass
consumes them.
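
For context, here is a minimal sketch of how a subclass could consume such an
extra option through the new signatures. It mirrors the updated
`save_checkpoint`/`load_checkpoint` interface on a standalone class;
`MyCheckpointEngine`, `async_save`, and `strict` are hypothetical names used
for illustration and are not part of this PR or of verl.

```python
from typing import Optional


class MyCheckpointEngine:
    """Hypothetical engine following the new BaseEngine checkpoint signatures."""

    def save_checkpoint(
        self,
        local_path: str,
        hdfs_path: Optional[str] = None,
        global_step: int = 0,
        max_ckpt_to_keep: Optional[int] = None,
        **kwargs,
    ) -> None:
        # Consume an implementation-specific option; unrecognized keys are ignored.
        async_save = kwargs.pop("async_save", False)
        print(f"saving to {local_path} at step {global_step} (async_save={async_save})")

    def load_checkpoint(
        self,
        local_path: str,
        hdfs_path: Optional[str] = None,
        del_local_after_load: bool = True,
        **kwargs,
    ) -> None:
        # Same pattern on the load path.
        strict = kwargs.pop("strict", True)
        print(f"loading from {local_path} (strict={strict})")


engine = MyCheckpointEngine()
# Existing call sites keep working unchanged...
engine.save_checkpoint("/tmp/ckpt", global_step=10)
# ...while callers that know about a backend-specific option can pass it.
engine.save_checkpoint("/tmp/ckpt", global_step=10, async_save=True)
```

Because the base signatures accept `**kwargs`, the same call sites work against
any engine, and backend-specific options only matter to the subclass that reads
them.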


### Checklist Before Starting

- [x] Search for similar PRs. Paste at least one query link here: ...
- [x] Format the PR title as `[{modules}] {type}: {description}` (This
will be checked by the CI)
  - `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`,
    `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`,
    `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`,
    `env`, `tool`, `ckpt`, `doc`, `data`
  - If this PR involves multiple modules, separate them with `,` like
    `[megatron, fsdp, doc]`
  - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test`
  - If this PR breaks any API (CLI arguments, config, function signature,
    etc.), add `[BREAKING]` to the beginning of the title.
    - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching`

### Checklist Before Submitting

> [!IMPORTANT]
> Please check all the following items before requesting a review,
> otherwise the reviewer might deprioritize this PR for review.

- [x] Read the [Contribute
Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md).
- [x] Apply [pre-commit
checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting):
`pre-commit install && pre-commit run --all-files --show-diff-on-failure
--color=always`
- [ ] Add / Update [the
documentation](https://github.com/volcengine/verl/tree/main/docs).
- [ ] Add unit or end-to-end test(s) to [the CI
workflow](https://github.com/volcengine/verl/tree/main/.github/workflows)
to cover all the code. If not feasible, explain why: ...
- [ ] Once your PR is ready for CI, send a message in [the `ci-request`
channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the
`verl` Slack
workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ).
(If not accessible, please try [the Feishu group
(飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).)

---------

Signed-off-by: Hongpeng Guo <hg5@illinois.edu>
Co-authored-by: wuxibin <wuxibin@bytedance.com>
Author: Hongpeng Guo
Date: 2025-10-08 23:56:22 -07:00
Committed by: GitHub
Parent: 54fed7fec7
Commit: e56e3df071

4 changed files with 39 additions and 10 deletions


@@ -149,7 +149,14 @@ class BaseEngine:
         """
         raise NotImplementedError
 
-    def save_checkpoint(self, local_path, hdfs_path=None, global_step=0, max_ckpt_to_keep=None):
+    def save_checkpoint(
+        self,
+        local_path: str,
+        hdfs_path: Optional[str] = None,
+        global_step: int = 0,
+        max_ckpt_to_keep: Optional[int] = None,
+        **kwargs,
+    ) -> None:
         """
         Save model, optimizer, and scheduler states to a checkpoint.
@@ -158,10 +165,13 @@ class BaseEngine:
             hdfs_path: Optional HDFS path to copy checkpoint.
             global_step: Integer training step number for naming.
             max_ckpt_to_keep: Maximum number of recent checkpoints to retain.
+            **kwargs: Arbitrary keyword arguments.
         """
         raise NotImplementedError
 
-    def load_checkpoint(self, local_path, hdfs_path=None, del_local_after_load=True):
+    def load_checkpoint(
+        self, local_path: str, hdfs_path: Optional[str] = None, del_local_after_load: bool = True, **kwargs
+    ) -> None:
         """
         Load model, optimizer, and scheduler states from a checkpoint.
@@ -169,6 +179,7 @@ class BaseEngine:
             local_path: Local filesystem path of the checkpoint.
             hdfs_path: Optional HDFS path where checkpoint is stored.
             del_local_after_load: Whether to delete local copy after loading.
+            **kwargs: Arbitrary keyword arguments.
         """
         raise NotImplementedError


@@ -20,7 +20,7 @@ import logging
 import os
 import warnings
 from contextlib import nullcontext
-from typing import Callable
+from typing import Callable, Optional
 
 import torch
 import torch.distributed
@@ -562,7 +562,14 @@ class FSDPEngine(BaseEngine):
         else:
             raise ValueError(f"Invalid device type: {device}")
 
-    def save_checkpoint(self, local_path, hdfs_path=None, global_step=0, max_ckpt_to_keep=None):
+    def save_checkpoint(
+        self,
+        local_path: str,
+        hdfs_path: Optional[str] = None,
+        global_step: int = 0,
+        max_ckpt_to_keep: Optional[int] = None,
+        **kwargs,
+    ) -> None:
         """
         Save FSDP checkpoint, handling parameter offload as needed.
         """
@@ -577,7 +584,9 @@ class FSDPEngine(BaseEngine):
         if self._is_offload_param:
             offload_fsdp_model_to_cpu(self.module)
 
-    def load_checkpoint(self, local_path, hdfs_path=None, del_local_after_load=True):
+    def load_checkpoint(
+        self, local_path: str, hdfs_path: Optional[str] = None, del_local_after_load: bool = True, **kwargs
+    ) -> None:
         """
         Load FSDP checkpoint, restoring parameters and optimizer state.
         """


@@ -15,7 +15,7 @@
 import logging
 import os
 from functools import partial
-from typing import Any, Callable, Iterator
+from typing import Any, Callable, Iterator, Optional
 
 import torch
 import torch.distributed
@@ -331,7 +331,14 @@ class MegatronEngine(BaseEngine):
     def get_data_parallel_group(self):
         return mpu.get_data_parallel_group()
 
-    def save_checkpoint(self, local_path, hdfs_path=None, global_step=0, max_ckpt_to_keep=None):
+    def save_checkpoint(
+        self,
+        local_path: str,
+        hdfs_path: Optional[str] = None,
+        global_step: int = 0,
+        max_ckpt_to_keep: Optional[int] = None,
+        **kwargs,
+    ) -> None:
         """
         Save model, optimizer, and scheduler states to a checkpoint.
@@ -350,7 +357,9 @@ class MegatronEngine(BaseEngine):
         if self._is_offload_param:
             offload_megatron_model_to_cpu(self.module)
 
-    def load_checkpoint(self, local_path, hdfs_path=None, del_local_after_load=True):
+    def load_checkpoint(
+        self, local_path: str, hdfs_path: Optional[str] = None, del_local_after_load: bool = True, **kwargs
+    ) -> None:
         """
         Load model, optimizer, and scheduler states from a checkpoint.


@@ -181,8 +181,8 @@ class AsyncRolloutRequest(BaseModel):
             # Only log the warning to avoid truncating in the middle of generation prompt. Consider raising an
             # error for this case in the future.
             # Ensure batch_data_id exists with default value if not provided
-            if 'batch_data_id' not in values:
-                values['batch_data_id'] = cls.model_fields['batch_data_id'].default
+            if "batch_data_id" not in values:
+                values["batch_data_id"] = cls.model_fields["batch_data_id"].default
             logger.warning(
                 f"Prompt {values['batch_data_id']} has length {values['input_ids'].shape[-1]} "
                 f"which is greater than max_prompt_len {max_prompt_len} after applied chat template with tools."