Mirror of https://github.com/volcengine/verl.git, synced 2025-10-20 13:43:50 +08:00
[worker] refactor: Add kwargs to checkpoint related functions in BaseEngine and its subclasses (#3662)
### What does this PR do?

Add `**kwargs` to the checkpoint APIs of `BaseEngine` (and thread them through `FSDPEngine`/`MegatronEngine`) to allow engines and pluggable checkpoint backends to accept implementation-specific options without changing the common interface. This enables extension when users subclass `BaseEngine` or integrate internal engines, while preserving backward compatibility: existing calls remain unchanged, and extra keys are simply ignored unless a subclass consumes them.

### Checklist Before Starting

- [x] Search for similar PRs. Paste at least one query link here: ...
- [x] Format the PR title as `[{modules}] {type}: {description}` (This will be checked by the CI)
  - `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data`
  - If this PR involves multiple modules, separate them with `,`, like `[megatron, fsdp, doc]`
  - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test`
  - If this PR breaks any API (CLI arguments, config, function signature, etc.), add `[BREAKING]` to the beginning of the title.
  - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching`

### Checklist Before Submitting

> [!IMPORTANT]
> Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review.

- [x] Read the [Contribute Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md).
- [x] Apply [pre-commit checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting): `pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always`
- [ ] Add / Update [the documentation](https://github.com/volcengine/verl/tree/main/docs).
- [ ] Add unit or end-to-end test(s) to [the CI workflow](https://github.com/volcengine/verl/tree/main/.github/workflows) to cover all the code. If not feasible, explain why: ...
- [ ] Once your PR is ready for CI, send a message in [the `ci-request` channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the `verl` Slack workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ). (If not accessible, please try [the Feishu group (飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).)

---------

Signed-off-by: Hongpeng Guo <hg5@illinois.edu>
Co-authored-by: wuxibin <wuxibin@bytedance.com>
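To make the extension pattern described above concrete, the following is a minimal, self-contained sketch of how a subclass could consume an implementation-specific option through the new `**kwargs` while pre-existing call sites stay untouched. The `BaseEngineSketch` stand-in, the `InternalEngine` subclass, and the `async_upload` option are invented for illustration and are not part of verl or of this PR.

```python
from typing import Optional


class BaseEngineSketch:
    """Minimal stand-in mirroring the signature introduced in this PR (not the real verl BaseEngine)."""

    def save_checkpoint(
        self,
        local_path: str,
        hdfs_path: Optional[str] = None,
        global_step: int = 0,
        max_ckpt_to_keep: Optional[int] = None,
        **kwargs,
    ) -> None:
        raise NotImplementedError


class InternalEngine(BaseEngineSketch):
    """Hypothetical internal engine that consumes a backend-specific option via **kwargs."""

    def save_checkpoint(
        self,
        local_path: str,
        hdfs_path: Optional[str] = None,
        global_step: int = 0,
        max_ckpt_to_keep: Optional[int] = None,
        **kwargs,
    ) -> None:
        # "async_upload" is an illustrative option; callers that never pass it are unaffected.
        async_upload = kwargs.get("async_upload", False)
        print(f"saving to {local_path} at step {global_step} (async_upload={async_upload})")


engine = InternalEngine()
engine.save_checkpoint("/tmp/ckpt", global_step=100)                     # pre-existing call, unchanged
engine.save_checkpoint("/tmp/ckpt", global_step=100, async_upload=True)  # new option threads through
```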
@@ -149,7 +149,14 @@ class BaseEngine:
         """
         raise NotImplementedError
 
-    def save_checkpoint(self, local_path, hdfs_path=None, global_step=0, max_ckpt_to_keep=None):
+    def save_checkpoint(
+        self,
+        local_path: str,
+        hdfs_path: Optional[str] = None,
+        global_step: int = 0,
+        max_ckpt_to_keep: Optional[int] = None,
+        **kwargs,
+    ) -> None:
         """
         Save model, optimizer, and scheduler states to a checkpoint.
 
@@ -158,10 +165,13 @@ class BaseEngine:
             hdfs_path: Optional HDFS path to copy checkpoint.
             global_step: Integer training step number for naming.
             max_ckpt_to_keep: Maximum number of recent checkpoints to retain.
+            **kwargs: Arbitrary keyword arguments.
         """
         raise NotImplementedError
 
-    def load_checkpoint(self, local_path, hdfs_path=None, del_local_after_load=True):
+    def load_checkpoint(
+        self, local_path: str, hdfs_path: Optional[str] = None, del_local_after_load: bool = True, **kwargs
+    ) -> None:
         """
         Load model, optimizer, and scheduler states from a checkpoint.
 
@@ -169,6 +179,7 @@ class BaseEngine:
             local_path: Local filesystem path of the checkpoint.
             hdfs_path: Optional HDFS path where checkpoint is stored.
             del_local_after_load: Whether to delete local copy after loading.
+            **kwargs: Arbitrary keyword arguments.
         """
         raise NotImplementedError
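The abstract signatures above only define the contract. As a hedged sketch of how a concrete engine might hand the extra options to a pluggable checkpoint backend, consider the following; the `_CheckpointBackendStub` class and the forwarding logic are assumptions made for illustration, not code taken from the FSDPEngine or MegatronEngine changes shown below.

```python
from typing import Optional


class _CheckpointBackendStub:
    """Illustrative stand-in for a pluggable checkpoint backend (not a real verl class)."""

    def save(self, path: str, **backend_options) -> None:
        print(f"writing checkpoint to {path} with backend options {backend_options}")


class EngineWithBackend:
    """Sketch: a concrete engine forwards **kwargs to its backend without widening the common API."""

    def __init__(self) -> None:
        self.checkpoint_backend = _CheckpointBackendStub()

    def save_checkpoint(
        self,
        local_path: str,
        hdfs_path: Optional[str] = None,
        global_step: int = 0,
        max_ckpt_to_keep: Optional[int] = None,
        **kwargs,
    ) -> None:
        # Options the engine itself does not interpret are passed straight through to the backend.
        self.checkpoint_backend.save(f"{local_path}/step_{global_step}", **kwargs)


EngineWithBackend().save_checkpoint("/tmp/ckpt", global_step=10, save_sharded=True)
```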
@@ -20,7 +20,7 @@ import logging
 import os
 import warnings
 from contextlib import nullcontext
-from typing import Callable
+from typing import Callable, Optional
 
 import torch
 import torch.distributed
@@ -562,7 +562,14 @@ class FSDPEngine(BaseEngine):
         else:
             raise ValueError(f"Invalid device type: {device}")
 
-    def save_checkpoint(self, local_path, hdfs_path=None, global_step=0, max_ckpt_to_keep=None):
+    def save_checkpoint(
+        self,
+        local_path: str,
+        hdfs_path: Optional[str] = None,
+        global_step: int = 0,
+        max_ckpt_to_keep: Optional[int] = None,
+        **kwargs,
+    ) -> None:
         """
         Save FSDP checkpoint, handling parameter offload as needed.
         """
@@ -577,7 +584,9 @@ class FSDPEngine(BaseEngine):
         if self._is_offload_param:
             offload_fsdp_model_to_cpu(self.module)
 
-    def load_checkpoint(self, local_path, hdfs_path=None, del_local_after_load=True):
+    def load_checkpoint(
+        self, local_path: str, hdfs_path: Optional[str] = None, del_local_after_load: bool = True, **kwargs
+    ) -> None:
         """
         Load FSDP checkpoint, restoring parameters and optimizer state.
         """
@@ -15,7 +15,7 @@
 import logging
 import os
 from functools import partial
-from typing import Any, Callable, Iterator
+from typing import Any, Callable, Iterator, Optional
 
 import torch
 import torch.distributed
@@ -331,7 +331,14 @@ class MegatronEngine(BaseEngine):
     def get_data_parallel_group(self):
         return mpu.get_data_parallel_group()
 
-    def save_checkpoint(self, local_path, hdfs_path=None, global_step=0, max_ckpt_to_keep=None):
+    def save_checkpoint(
+        self,
+        local_path: str,
+        hdfs_path: Optional[str] = None,
+        global_step: int = 0,
+        max_ckpt_to_keep: Optional[int] = None,
+        **kwargs,
+    ) -> None:
         """
         Save model, optimizer, and scheduler states to a checkpoint.
 
@@ -350,7 +357,9 @@ class MegatronEngine(BaseEngine):
         if self._is_offload_param:
             offload_megatron_model_to_cpu(self.module)
 
-    def load_checkpoint(self, local_path, hdfs_path=None, del_local_after_load=True):
+    def load_checkpoint(
+        self, local_path: str, hdfs_path: Optional[str] = None, del_local_after_load: bool = True, **kwargs
+    ) -> None:
         """
         Load model, optimizer, and scheduler states from a checkpoint.
 
@@ -181,8 +181,8 @@ class AsyncRolloutRequest(BaseModel):
             # Only log the warning to avoid truncating in the middle of generation prompt. Consider raising an
             # error for this case in the future.
             # Ensure batch_data_id exists with default value if not provided
-            if 'batch_data_id' not in values:
-                values['batch_data_id'] = cls.model_fields['batch_data_id'].default
+            if "batch_data_id" not in values:
+                values["batch_data_id"] = cls.model_fields["batch_data_id"].default
             logger.warning(
                 f"Prompt {values['batch_data_id']} has length {values['input_ids'].shape[-1]} "
                 f"which is greater than max_prompt_len {max_prompt_len} after applied chat template with tools."
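As a side note on the `model_fields` default-fill pattern used in the hunk above, here is a minimal Pydantic v2 sketch; the `RequestSketch` model is invented for illustration and is not the real `AsyncRolloutRequest`.

```python
from pydantic import BaseModel


class RequestSketch(BaseModel):
    """Toy model illustrating the default-fill pattern (not the real AsyncRolloutRequest)."""

    batch_data_id: int = 0


values: dict = {}
# Same pattern as above: fall back to the field's declared default when the key is missing.
if "batch_data_id" not in values:
    values["batch_data_id"] = RequestSketch.model_fields["batch_data_id"].default

print(values)  # {'batch_data_id': 0}
```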