Mirror of https://github.com/volcengine/verl.git, synced 2025-10-20 13:43:50 +08:00
### What does this PR do?

- Add an open math reasoning recipe using the SFT trainer with the model engine
- Support setting the validation dataset to `None` in the SFT trainer
- Fix `main_eval`
- Use `aiohttp` in `main_generation_server` to avoid hangs in `AsyncOpenAI`

### Checklist Before Starting

- [ ] Search for similar PRs. Paste at least one query link here: ...
- [ ] Format the PR title as `[{modules}] {type}: {description}` (this will be checked by the CI)
  - `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data`
  - If this PR involves multiple modules, separate them with `,`, like `[megatron, fsdp, doc]`
  - `{type}` is one of `feat`, `fix`, `refactor`, `chore`, `test`
  - If this PR breaks any API (CLI arguments, config, function signature, etc.), add `[BREAKING]` to the beginning of the title.
  - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching`

### Test

> For changes that cannot be tested by CI (e.g., algorithm implementation, new model support), validate by experiment(s) and show results such as training curve plots, evaluation results, etc.

### API and Usage Example

> Demonstrate how the API changes, if any, and provide usage example(s) if possible.

```python
# Add code snippet or script demonstrating how to use this
```

### Design & Code Changes

> Demonstrate the high-level design if this PR is complex, and list the specific changes.

### Checklist Before Submitting

> [!IMPORTANT]
> Please check all the following items before requesting a review; otherwise the reviewer might deprioritize this PR.

- [ ] Read the [Contribute Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md).
- [ ] Apply [pre-commit checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting): `pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always`
- [ ] Add / update [the documentation](https://github.com/volcengine/verl/tree/main/docs).
- [ ] Add unit or end-to-end test(s) to [the CI workflow](https://github.com/volcengine/verl/tree/main/.github/workflows) to cover all the code. If not feasible, explain why: ...
- [ ] Once your PR is ready for CI, send a message in [the `ci-request` channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the `verl` Slack workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ). (If not accessible, please try [the Feishu group (飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).)

---------

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
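A minimal usage sketch of the val-dataset change (the config keys and trainer entrypoint below are assumptions for illustration, not taken from this PR's diff):

```python
# Hypothetical sketch: with this change, the SFT trainer is expected to skip
# validation when no validation dataset is configured, instead of requiring one.
config.data.val_files = None  # no val dataset; validation steps are skipped
trainer = SFTTrainer(config=config)  # hypothetical entrypoint name
trainer.fit()
```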
Python · 239 lines · 8.6 KiB
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import random
import shutil

import numpy as np
import torch
import torch.distributed
from omegaconf import DictConfig
from transformers import PreTrainedTokenizer, ProcessorMixin

from verl.trainer.config import CheckpointConfig
from verl.utils.device import get_device_name, get_torch_device


class BaseCheckpointManager:
    """
    A checkpoint manager that saves and loads the following states in an SPMD way:
    - model
    - optimizer
    - lr_scheduler
    - extra_states

    We save
    - sharded model states and optimizer states
    - full lr_scheduler states
    - huggingface tokenizer and config for ckpt merge
    """

    def __init__(
        self,
        model,
        optimizer: torch.optim.Optimizer,
        lr_scheduler: torch.optim.lr_scheduler.LRScheduler = None,
        processing_class: PreTrainedTokenizer | ProcessorMixin = None,
        checkpoint_config: DictConfig | CheckpointConfig = None,
    ):
        self.checkpoint_config = checkpoint_config
        checkpoint_load_contents = checkpoint_config.get("load_contents", None) if checkpoint_config else None
        checkpoint_save_contents = checkpoint_config.get("save_contents", None) if checkpoint_config else None
        if checkpoint_load_contents is None:
            checkpoint_load_contents = ["model", "optimizer", "extra"]
        if checkpoint_save_contents is None:
            checkpoint_save_contents = ["model", "optimizer", "extra"]
        self.previous_global_step = None
        self.previous_saved_paths = []

        self.model = model
        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler
        self.processing_class = processing_class
        self.checkpoint_load_contents = checkpoint_load_contents
        self.checkpoint_save_contents = checkpoint_save_contents

        self.rank = torch.distributed.get_rank()
        self.world_size = torch.distributed.get_world_size()

    @property
    def should_save_model(self) -> bool:
        """
        Returns True if 'model' is in checkpoint_save_contents, indicating the model state should be saved.
        """
        return "model" in self.checkpoint_save_contents

    @property
    def should_save_optimizer(self) -> bool:
        """
        Returns True if 'optimizer' is in checkpoint_save_contents, indicating the optimizer state should be saved.
        """
        return "optimizer" in self.checkpoint_save_contents

    @property
    def should_save_extra(self) -> bool:
        """
        Returns True if 'extra' is in checkpoint_save_contents, indicating the extra state should be saved.
        """
        return "extra" in self.checkpoint_save_contents

    @property
    def should_save_hf_model(self) -> bool:
        """
        Returns True if 'hf_model' is in checkpoint_save_contents, indicating the model should be converted to an
        hf model and saved.
        """
        return "hf_model" in self.checkpoint_save_contents

    @property
    def should_load_model(self) -> bool:
        """
        Returns True if 'model' is in checkpoint_load_contents, indicating the model state should be loaded.
        """
        return "model" in self.checkpoint_load_contents

    @property
    def should_load_optimizer(self) -> bool:
        """
        Returns True if 'optimizer' is in checkpoint_load_contents, indicating the optimizer state should be loaded.
        """
        return "optimizer" in self.checkpoint_load_contents

    @property
    def should_load_extra(self) -> bool:
        """
        Returns True if 'extra' is in checkpoint_load_contents, indicating the extra state should be loaded.
        """
        return "extra" in self.checkpoint_load_contents

    def load_checkpoint(self, local_path: str, hdfs_path: str = None, del_local_after_load: bool = False):
        raise NotImplementedError

    def save_checkpoint(
        self, local_path: str, hdfs_path: str = None, global_step: int = 0, max_ckpt_to_keep: int = None
    ):
        raise NotImplementedError

    @staticmethod
    def checkpath(local_path: str, hdfs_path: str):
        assert local_path is not None or hdfs_path is not None, "local_path and hdfs_path cannot both be None"
        return local_path is not None, local_path if local_path is not None else hdfs_path

    def remove_previous_save_local_path(self, path):
        if isinstance(path, str):
            path = [path]
        for p in path:
            abs_path = os.path.abspath(p)
            print(f"Checkpoint manager removing previous saved local path: {abs_path}")
            if not os.path.exists(abs_path):
                continue
            shutil.rmtree(abs_path, ignore_errors=True)

    @staticmethod
    def get_rng_state():
        rng_state = {
            "cpu": torch.get_rng_state(),
            "numpy": np.random.get_state(),
            "random": random.getstate(),
        }

        if get_device_name() != "cpu":
            rng_state[get_device_name()] = get_torch_device().get_rng_state()

        return rng_state

    @staticmethod
    def load_rng_state(rng_state):
        torch.set_rng_state(rng_state["cpu"])
        np.random.set_state(rng_state["numpy"])
        random.setstate(rng_state["random"])

        if get_device_name() != "cpu":
            get_torch_device().set_rng_state(rng_state[get_device_name()])


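# Illustrative sketch (not part of the original module): capturing the RNG
# state before a save and restoring it on resume keeps dataloader shuffling
# and dropout reproducible across restarts. Concrete subclasses would call
# these helpers inside their save_checkpoint / load_checkpoint implementations.
def _example_rng_round_trip():
    rng_state = BaseCheckpointManager.get_rng_state()
    # ... a subclass would persist this, e.g. torch.save(rng_state, "extra_state.pt") ...
    BaseCheckpointManager.load_rng_state(rng_state)

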
def find_latest_ckpt_path(path, directory_format="global_step_{}"):
    """
    Return the most recent checkpoint directory based on a tracker file.

    Args:
        path (str): Base directory containing the checkpoint tracker.
        directory_format (str): Template for checkpoint subfolders with one
            placeholder for the iteration number (default "global_step_{}").

    Returns:
        str or None: Full path to the latest checkpoint directory, or
            None if the tracker or checkpoint folder is missing.
    """
    if path is None:
        return None

    tracker_file = get_checkpoint_tracker_filename(path)
    if not os.path.exists(tracker_file):
        if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
            print(f"Checkpoint tracker file does not exist: {tracker_file}")
        return None

    with open(tracker_file, "rb") as f:
        iteration = int(f.read().decode())
    ckpt_path = os.path.join(path, directory_format.format(iteration))
    if not os.path.exists(ckpt_path):
        print(f"Checkpoint does not exist: {ckpt_path}")
        return None

    print(f"Found checkpoint: {ckpt_path}")
    return ckpt_path


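# Illustrative sketch (not part of the original module): how a resume path
# might use find_latest_ckpt_path. The checkpoints root is a hypothetical
# argument, typically taken from the trainer config.
def _example_resume_step(checkpoints_root: str) -> int:
    latest_ckpt = find_latest_ckpt_path(checkpoints_root)
    if latest_ckpt is None:
        return 0  # no tracker file or checkpoint yet: start from step 0
    # directory_format is "global_step_{n}", so the step is recoverable from the name
    return int(os.path.basename(latest_ckpt).removeprefix("global_step_"))

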
def get_checkpoint_tracker_filename(root_path: str):
    """
    The tracker file records the latest checkpoint saved during training, so a run can restart from it.
    """
    return os.path.join(root_path, "latest_checkpointed_iteration.txt")


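# Illustrative sketch (not part of the original module): a save path would
# typically rewrite this tracker file after each successful checkpoint so that
# find_latest_ckpt_path can locate it on the next run.
def _example_update_tracker(root_path: str, global_step: int) -> None:
    with open(get_checkpoint_tracker_filename(root_path), "w") as f:
        f.write(str(global_step))

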
def should_save_ckpt_esi(max_steps_duration: float, save_ckpt_duration: float = 60, redundant_time: float = 0) -> bool:
    """
    Determine whether a checkpoint should be saved based on the capacity block (ESI) expiration time,
    i.e., whether the remaining capacity window is only just enough for one more step plus a save.

    Args:
        max_steps_duration: Max estimated time (seconds) required to complete one training step
        save_ckpt_duration: Estimated time (seconds) required to save a checkpoint (default: 60)
        redundant_time: Additional buffer time (seconds) for unexpected delays (default: 0)
    """
    exp_ts_mlp = os.getenv("MLP_CURRENT_CAPACITY_BLOCK_EXPIRATION_TIMESTAMP")  # vemlp
    exp_ts_aws = os.getenv("SAGEMAKER_CURRENT_CAPACITY_BLOCK_EXPIRATION_TIMESTAMP")  # aws
    if exp_ts_mlp:
        try:
            import time

            remaining = float(exp_ts_mlp) - time.time()
        except ValueError:
            return False
        return (
            remaining > 0
            and max_steps_duration > 0
            and remaining <= save_ckpt_duration + max_steps_duration + redundant_time
        )
    elif exp_ts_aws:
        from datetime import datetime, timedelta

        expiration_time = datetime.fromtimestamp(int(exp_ts_aws))
        time_difference = expiration_time - datetime.now()
        threshold_minutes = (save_ckpt_duration + max_steps_duration + redundant_time) / 60
        return time_difference < timedelta(minutes=threshold_minutes)
    else:
        return False
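

# Illustrative sketch (not part of the original module): a training loop might
# combine the ESI check with its regular save cadence. `trainer`, `save_freq`,
# and `max_step_seconds` are hypothetical names.
def _example_maybe_save(trainer, global_step: int, save_freq: int, max_step_seconds: float) -> None:
    # Save on the usual cadence, or early if the capacity block is about to expire.
    if global_step % save_freq == 0 or should_save_ckpt_esi(max_steps_duration=max_step_seconds):
        trainer.save_checkpoint(local_path=f"checkpoints/global_step_{global_step}", global_step=global_step)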