[worker] refactor: Add kwargs to checkpoint related functions in BaseEngine and its subclasses (#3662)

### What does this PR do?

Add `**kwargs` to the checkpoint APIs of `BaseEngine` (and thread them
through `FSDPEngine`/`MegatronEngine`) so that engines and pluggable
checkpoint backends can accept implementation-specific options without
changing the common interface. This makes the API extensible for users
who subclass `BaseEngine` or integrate internal engines, while
preserving backward compatibility: existing calls remain unchanged, and
extra keys are simply ignored unless a subclass consumes them.
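
For example, a user-defined engine can pull its own options out of `**kwargs` while the shared signature stays untouched. A minimal sketch (the `InternalEngine` class, the `shard_format` option, and the import path are hypothetical, for illustration only):

```python
from typing import Optional

from verl.workers.engine import BaseEngine  # import path assumed


class InternalEngine(BaseEngine):
    """Hypothetical engine consuming one implementation-specific option."""

    def save_checkpoint(
        self,
        local_path: str,
        hdfs_path: Optional[str] = None,
        global_step: int = 0,
        max_ckpt_to_keep: Optional[int] = None,
        **kwargs,
    ) -> None:
        # Pop the option this backend understands; leftover keys are ignored.
        shard_format = kwargs.pop("shard_format", "safetensors")
        ...  # write the checkpoint shards in `shard_format`
```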


### Checklist Before Starting

- [x] Search for similar PRs. Paste at least one query link here: ...
- [x] Format the PR title as `[{modules}] {type}: {description}` (This will be checked by the CI)
  - `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data`
  - If this PR involves multiple modules, separate them with `,` like `[megatron, fsdp, doc]`
  - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test`
  - If this PR breaks any API (CLI arguments, config, function signature, etc.), add `[BREAKING]` to the beginning of the title.
    - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching`

### Checklist Before Submitting

> [!IMPORTANT]
> Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review.

- [x] Read the [Contribute
Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md).
- [x] Apply [pre-commit
checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting):
`pre-commit install && pre-commit run --all-files --show-diff-on-failure
--color=always`
- [ ] Add / Update [the
documentation](https://github.com/volcengine/verl/tree/main/docs).
- [ ] Add unit or end-to-end test(s) to [the CI
workflow](https://github.com/volcengine/verl/tree/main/.github/workflows)
to cover all the code. If not feasible, explain why: ...
- [ ] Once your PR is ready for CI, send a message in [the `ci-request`
channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the
`verl` Slack
workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ).
(If not accessible, please try [the Feishu group
(飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).)

---------

Signed-off-by: Hongpeng Guo <hg5@illinois.edu>
Co-authored-by: wuxibin <wuxibin@bytedance.com>
Commit e56e3df071 (parent 54fed7fec7), authored by Hongpeng Guo, committed via GitHub on 2025-10-08. 4 changed files with 39 additions and 10 deletions.

Changes to `BaseEngine`:

```diff
@@ -149,7 +149,14 @@ class BaseEngine:
         """
         raise NotImplementedError
 
-    def save_checkpoint(self, local_path, hdfs_path=None, global_step=0, max_ckpt_to_keep=None):
+    def save_checkpoint(
+        self,
+        local_path: str,
+        hdfs_path: Optional[str] = None,
+        global_step: int = 0,
+        max_ckpt_to_keep: Optional[int] = None,
+        **kwargs,
+    ) -> None:
         """
         Save model, optimizer, and scheduler states to a checkpoint.
@@ -158,10 +165,13 @@ class BaseEngine:
             hdfs_path: Optional HDFS path to copy checkpoint.
             global_step: Integer training step number for naming.
             max_ckpt_to_keep: Maximum number of recent checkpoints to retain.
+            **kwargs: Arbitrary keyword arguments.
         """
         raise NotImplementedError
 
-    def load_checkpoint(self, local_path, hdfs_path=None, del_local_after_load=True):
+    def load_checkpoint(
+        self, local_path: str, hdfs_path: Optional[str] = None, del_local_after_load: bool = True, **kwargs
+    ) -> None:
         """
         Load model, optimizer, and scheduler states from a checkpoint.
@@ -169,6 +179,7 @@ class BaseEngine:
             local_path: Local filesystem path of the checkpoint.
             hdfs_path: Optional HDFS path where checkpoint is stored.
             del_local_after_load: Whether to delete local copy after loading.
+            **kwargs: Arbitrary keyword arguments.
         """
         raise NotImplementedError
```
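
At the call site nothing changes for existing code, and engine-specific options ride along through the same interface (`async_save` below is a made-up option, not something this PR defines):

```python
# Existing call: works exactly as before.
engine.save_checkpoint(local_path="/tmp/ckpts/step_100", global_step=100)

# Extra, engine-specific options are collected by **kwargs; engines that
# do not recognize them simply ignore the extra keys.
engine.save_checkpoint(
    local_path="/tmp/ckpts/step_100",
    global_step=100,
    async_save=True,  # hypothetical backend option
)
```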

Changes to `FSDPEngine` (plus the `Optional` import; the original patch annotated `del_local_after_load` as `int`, corrected to `bool` here to match the base class):

```diff
@@ -20,7 +20,7 @@ import logging
 import os
 import warnings
 from contextlib import nullcontext
-from typing import Callable
+from typing import Callable, Optional
 
 import torch
 import torch.distributed
@@ -562,7 +562,14 @@ class FSDPEngine(BaseEngine):
         else:
             raise ValueError(f"Invalid device type: {device}")
 
-    def save_checkpoint(self, local_path, hdfs_path=None, global_step=0, max_ckpt_to_keep=None):
+    def save_checkpoint(
+        self,
+        local_path: str,
+        hdfs_path: Optional[str] = None,
+        global_step: int = 0,
+        max_ckpt_to_keep: Optional[int] = None,
+        **kwargs,
+    ) -> None:
         """
         Save FSDP checkpoint, handling parameter offload as needed.
         """
@@ -577,7 +584,9 @@ class FSDPEngine(BaseEngine):
         if self._is_offload_param:
             offload_fsdp_model_to_cpu(self.module)
 
-    def load_checkpoint(self, local_path, hdfs_path=None, del_local_after_load=True):
+    def load_checkpoint(
+        self, local_path: str, hdfs_path: Optional[str] = None, del_local_after_load: bool = True, **kwargs
+    ) -> None:
         """
         Load FSDP checkpoint, restoring parameters and optimizer state.
         """
```

Changes to `MegatronEngine` (plus the `Optional` import):

```diff
@@ -15,7 +15,7 @@
 import logging
 import os
 from functools import partial
-from typing import Any, Callable, Iterator
+from typing import Any, Callable, Iterator, Optional
 
 import torch
 import torch.distributed
@@ -331,7 +331,14 @@ class MegatronEngine(BaseEngine):
     def get_data_parallel_group(self):
         return mpu.get_data_parallel_group()
 
-    def save_checkpoint(self, local_path, hdfs_path=None, global_step=0, max_ckpt_to_keep=None):
+    def save_checkpoint(
+        self,
+        local_path: str,
+        hdfs_path: Optional[str] = None,
+        global_step: int = 0,
+        max_ckpt_to_keep: Optional[int] = None,
+        **kwargs,
+    ) -> None:
         """
         Save model, optimizer, and scheduler states to a checkpoint.
@@ -350,7 +357,9 @@ class MegatronEngine(BaseEngine):
         if self._is_offload_param:
             offload_megatron_model_to_cpu(self.module)
 
-    def load_checkpoint(self, local_path, hdfs_path=None, del_local_after_load=True):
+    def load_checkpoint(
+        self, local_path: str, hdfs_path: Optional[str] = None, del_local_after_load: bool = True, **kwargs
+    ) -> None:
         """
         Load model, optimizer, and scheduler states from a checkpoint.
```

Changes to `AsyncRolloutRequest` (string quote normalization):

```diff
@@ -181,8 +181,8 @@ class AsyncRolloutRequest(BaseModel):
             # Only log the warning to avoid truncating in the middle of generation prompt. Consider raising an
             # error for this case in the future.
             # Ensure batch_data_id exists with default value if not provided
-            if 'batch_data_id' not in values:
-                values['batch_data_id'] = cls.model_fields['batch_data_id'].default
+            if "batch_data_id" not in values:
+                values["batch_data_id"] = cls.model_fields["batch_data_id"].default
             logger.warning(
                 f"Prompt {values['batch_data_id']} has length {values['input_ids'].shape[-1]} "
                 f"which is greater than max_prompt_len {max_prompt_len} after applied chat template with tools."
```