Mirror of https://github.com/volcengine/verl.git
synced 2025-10-20 13:43:50 +08:00
[trainer] refactor: Training Engine Interface and Development Plan (#1977)
# [Refactor] Training Engine Interface and Development Plan

## Motivation

See the original RFC for background: https://github.com/volcengine/verl/issues/1371

Modernizing our training loop requires that we:

- **Decouple** the training-backend implementation from algorithm code so each can evolve independently
- **Unify** on a single, well-defined `Engine` interface across FSDP/Megatron/etc. backends
- **Enable** unit testing of each backend implementation in isolation
- **Guarantee** that algorithm "roles" (Critic, Actor, Rollout, Ref) remain completely engine-agnostic

---

## Current Implementation

This PR:

- Introduces an abstract `BaseEngine` class that defines a unified training-engine interface.
- Implements `FSDPEngine`, a concrete `BaseEngine` using PyTorch FullyShardedDataParallel.
- Provides a `CriticWorker` based on `FSDPEngine` that plugs seamlessly into the existing PPO training code without any changes.

### Classic Training Loop with the New Interface

```python
# 1. Build and initialize engine
engine = FSDPEngine(config)
engine.init_model()
engine.set_loss_fn(loss_fn)

# 2. Training loop
for epoch in range(config.num_epochs):
    for batch in train_loader:
        # a) zero gradients
        engine.optimizer_zero_grad()

        # b) forward + backward
        with engine.train_mode():
            preds, loss, ctx = engine.forward_backward_step(
                batch,
                ctx,
                forward_only=False,
                preprocess_fn=preprocess_fn,
                postprocess_fn=postprocess_fn,
            )

        # c) update and schedule
        grad_norm = engine.optimizer_step()
        current_lr = engine.lr_scheduler_step()

# 3. Evaluation
with engine.eval_mode():
    for micro_batch in data:
        preds, ctx = engine.forward_backward_step(
            micro_batch,
            ctx,
            forward_only=True,
            preprocess_fn=preprocess_fn,
            postprocess_fn=postprocess_fn,
        )
```

### Detailed BaseEngine Interface

We now introduce an abstract base class, `BaseEngine`, which defines our unified training-engine interface.

**Key enhancements over the original RFC:**

- **`train_mode()` / `eval_mode()`**: context managers that control parameter and activation load/offload at the start and end of each loop.
- **`shard_data()` / `unshard_data()`**: APIs for partitioning and gathering data across devices or workers.
- **`preprocess_fn` / `postprocess_fn` in `forward_backward_step()`**: hooks to apply custom transformations before and after each micro-batch pass.

Below are the detailed signatures for each core method.

```python
class BaseEngine(object):
    """
    Abstract base class defining the interface for model training engines.

    Engine implementations must subclass BaseEngine and provide concrete behavior for all methods.
    """

    def __init__(self, config):
        """
        Initialize the BaseEngine.

        Args:
            config: Configuration object containing parameters for engine setup.
        """
        raise NotImplementedError

    def init_model(self):
        """
        Instantiate or load the model, optimizer, and learning rate scheduler.

        Should prepare all components necessary for training or evaluation.
        """
        raise NotImplementedError

    def train_mode(self):
        """
        Context manager entry for switching the engine and model into training mode.

        Usage:
            with engine.train_mode():
                # runs in training mode
        """
        raise NotImplementedError

    def eval_mode(self):
        """
        Context manager entry for switching the engine and model into evaluation mode.

        Usage:
            with engine.eval_mode():
                # runs in evaluation mode
        """
        raise NotImplementedError

    def forward_backward_step(self, batch, ctx=None, forward_only=False, preprocess_fn=None, postprocess_fn=None):
        """
        Execute a forward pass (and optional backward pass) over a batch of data.
        Args:
            batch: Raw batch data (e.g., tensors or mappings) to process.
            ctx: Optional context dict passed to preprocess/postprocess functions.
            forward_only: If True, skip gradient computation and backward pass.
            preprocess_fn: Function(batch, ctx) -> (inputs, ctx), applied before the model call.
            postprocess_fn: Function(outputs, ctx) -> (predictions, ctx), applied after the model call.

        Returns:
            If forward_only: (predictions, ctx)
            Else: (predictions, loss, ctx)
        """
        raise NotImplementedError

    def optimizer_zero_grad(self):
        """
        Zero out gradients of all parameters before starting a new backward pass.
        """
        raise NotImplementedError

    def optimizer_step(self):
        """
        Perform an optimization step to update model parameters based on accumulated gradients.

        Returns:
            grad_norm (float): The norm of the gradients before clipping or update.
        """
        raise NotImplementedError

    def lr_scheduler_step(self):
        """
        Advance the learning rate scheduler by one step.

        Returns:
            current_lr (float or list[float]): Updated learning rate(s).
        """
        raise NotImplementedError

    def shard_data(self, data):
        """
        Shard or partition data for distributed training or parallel execution.

        Args:
            data: Data structure to be sharded across devices/workers.

        Returns:
            Sharded data in the same format as input.
        """
        raise NotImplementedError

    def unshard_data(self, data):
        """
        Reconstruct or gather sharded data back to a unified format.

        Args:
            data: Sharded data structure to reconstruct.

        Returns:
            Unsharded, combined data.
        """
        raise NotImplementedError

    def set_loss_fn(self, loss_fn):
        """
        Set the loss function to be used during training.

        Args:
            loss_fn: Callable(data, predictions, ctx) -> (loss_tensor, new_ctx)
        """
        raise NotImplementedError

    def to(self, device: str, model: bool = True, optimizer: bool = True):
        """
        Move model parameters, optimizer states, or both to the specified device.

        Args:
            device: Target device identifier (e.g., "cuda" or "cpu").
            model: If True, move the model.
            optimizer: If True, move the optimizer states.
        """
        raise NotImplementedError

    def save_checkpoint(self, local_path, hdfs_path=None, global_step=0, max_ckpt_to_keep=None):
        """
        Save model, optimizer, and scheduler states to a checkpoint.

        Args:
            local_path: Local filesystem path to save checkpoint.
            hdfs_path: Optional HDFS path to copy checkpoint.
            global_step: Integer training step number for naming.
            max_ckpt_to_keep: Maximum number of recent checkpoints to retain.
        """
        raise NotImplementedError

    def load_checkpoint(self, local_path, hdfs_path=None, del_local_after_load=True):
        """
        Load model, optimizer, and scheduler states from a checkpoint.

        Args:
            local_path: Local filesystem path of the checkpoint.
            hdfs_path: Optional HDFS path where checkpoint is stored.
            del_local_after_load: Whether to delete local copy after loading.
        """
        raise NotImplementedError
```

### FSDPEngine Implementation

A concrete `FSDPEngine` implements all of these methods using PyTorch FullyShardedDataParallel, supporting all the features that the FSDP DPCritic worker supports:

- Multi-GPU model sharding
- Activation and optimizer offload
- LoRA and sequence parallelism
- Dynamic batch size and remove-padding

### CriticWorker Implementation Based on the FSDPEngine

- Unchanged public API
- Each role calls only `BaseEngine` methods (`init_model`, `train_mode`/`eval_mode`, `forward_backward_step`, etc.)
- No modifications needed in existing algorithms (e.g., PPO training)
- New roles can be plugged in identically to legacy code

A minimal illustrative sketch of such an engine-backed role is shown below.
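To make the engine/role split concrete, here is a minimal sketch of what a critic-style role could look like when written purely against the `BaseEngine` interface. This is an illustration only: the class and method names (`EngineBackedCritic`, `compute_values`, `update_critic`) and the simple squared-error loss are assumptions for the example, not the actual `CriticWorker` added in this PR.

```python
# Illustrative sketch only: a critic-style role that touches nothing but the BaseEngine API.
# Names and the loss below are hypothetical, not the CriticWorker shipped in this PR.
class EngineBackedCritic:
    def __init__(self, config, engine):
        self.config = config
        self.engine = engine                      # any BaseEngine implementation, e.g. FSDPEngine
        self.engine.init_model()
        self.engine.set_loss_fn(self.value_loss)  # loss_fn: (data, predictions, ctx) -> (loss, ctx)

    @staticmethod
    def value_loss(batch, predictions, ctx):
        # Hypothetical value loss for illustration: mean squared error against returns.
        returns = batch["returns"]
        loss = ((predictions - returns) ** 2).mean()
        return loss, ctx

    def compute_values(self, batch, ctx=None):
        # Inference path: the engine handles offload, sharding, and micro-batching internally.
        with self.engine.eval_mode():
            values, ctx = self.engine.forward_backward_step(batch, ctx, forward_only=True)
        return values

    def update_critic(self, batch, ctx=None):
        # Training path: forward + backward inside the engine, then optimizer/scheduler steps.
        metrics = {}
        with self.engine.train_mode():
            self.engine.optimizer_zero_grad()
            _, loss, ctx = self.engine.forward_backward_step(batch, ctx, forward_only=False)
            metrics["critic/grad_norm"] = self.engine.optimizer_step()
            metrics["critic/lr"] = self.engine.lr_scheduler_step()
        return metrics
```

Because the role never touches FSDP, Megatron, or device-placement details directly, swapping the backend only requires constructing a different `BaseEngine` implementation.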
## Development Plan

We'll roll this out in three gated phases, controlled by a feature flag (`use_legacy_worker_impl`).

### Phase 1: Engine Development

> Flag: `use_legacy_worker_impl = True` (default)
> New interface under active development

- Refactor Critic, Actor, Rollout, and Ref to use only `BaseEngine` APIs
- Design a hierarchical, immutable config system for engines/backends
- Ensure PPO training curves and final accuracy match the legacy implementation

### Phase 2: Migration

> Flag: `use_legacy_worker_impl = False` (default) – legacy path logs a deprecation warning
> All new code targets the new interface; 2–3 months of integration/stress testing

- Enforce the new interface for all feature work
- Gather benchmarks, bug reports, and performance data

### Phase 3: Cleanup

> After Phase 2 validation:

- Remove legacy worker code and flags
- Finalize documentation, update changelogs, close deprecation notices

Please review this refactor and share any feedback or concerns! Contributions are welcome.
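For reviewers who want to see how the pieces compose, below is a minimal sketch of how a registered engine is looked up through `EngineRegistry` and how the `use_legacy_worker_impl` flag selects between the legacy and engine-backed `CriticWorker`. It mirrors the `EngineRegistry` and `main_ppo.py` changes in the diff below; the `build_engine` and `select_critic_worker` helper names are assumptions for illustration, not APIs added by this PR.

```python
from verl.workers.engine import BaseEngine, EngineRegistry


def build_engine(backend_key: str, config) -> BaseEngine:
    # Illustrative helper: EngineRegistry.new instantiates the class registered
    # under `backend_key` (e.g., "fsdp" -> FSDPEngine) and raises
    # NotImplementedError for unknown keys.
    return EngineRegistry.new(backend_key, config)


def select_critic_worker(trainer_config):
    # Illustrative helper mirroring the flag handling added to main_ppo.py.
    use_legacy = trainer_config.get("use_legacy_worker_impl", "auto")
    if use_legacy in ["auto", "enable"]:
        from verl.workers.fsdp_workers import CriticWorker  # legacy implementation
    elif use_legacy == "disable":
        from verl.workers.roles import CriticWorker  # new engine-backed implementation
    else:
        raise ValueError(f"Invalid use_legacy_worker_impl: {use_legacy}")
    return CriticWorker
```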
This commit is contained in:

7	.github/workflows/e2e_ppo_trainer.yml (vendored)
@@ -101,10 +101,15 @@ jobs:
          ray stop --force
          python3 examples/data_preprocess/gsm8k.py
      # HF sanity
      - name: Running GSM8K E2E training tests on 1 L20 GPU with hf for santiy
      - name: Running GSM8K E2E training tests on 1 L20 GPU with hf for sanity
        run: |
          ray stop --force
          bash tests/special_e2e/ppo_trainer/run_single_gpu.sh
      # HF sanity
      - name: Running GSM8K E2E training tests on 1 L20 GPU with engine interface for sanity.
        run: |
          ray stop --force
          bash tests/special_e2e/ppo_trainer/run_single_gpu_with_engine.sh
      # Function RM
      - name: Running GSM8K E2E training tests on 8 L20 GPUs with rmpad using function rm with validation and saving (FSDP_SIZE=8)
        run: |
@@ -38,4 +38,5 @@ python3 -m verl.trainer.main_ppo \
    trainer.nnodes=1 \
    trainer.save_freq=20 \
    trainer.test_freq=1 \
    trainer.use_legacy_worker_impl=auto \
    trainer.total_epochs=15 $@
25	tests/special_e2e/ppo_trainer/run_single_gpu_with_engine.sh (new file)
@@ -0,0 +1,25 @@
PYTHONUNBUFFERED=1 python3 -m verl.trainer.main_ppo \
    data.train_files=$HOME/data/gsm8k/train.parquet \
    data.val_files=$HOME/data/gsm8k/test.parquet \
    data.train_batch_size=256 \
    data.max_prompt_length=512 \
    data.max_response_length=256 \
    actor_rollout_ref.model.path=Qwen/Qwen2.5-0.5B-Instruct \
    actor_rollout_ref.actor.optim.lr=1e-6 \
    actor_rollout_ref.actor.ppo_mini_batch_size=64 \
    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \
    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=8 \
    actor_rollout_ref.rollout.tensor_model_parallel_size=1 \
    actor_rollout_ref.rollout.gpu_memory_utilization=0.4 \
    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4 \
    critic.optim.lr=1e-5 \
    critic.model.path=Qwen/Qwen2.5-0.5B-Instruct \
    critic.ppo_micro_batch_size_per_gpu=4 \
    algorithm.kl_ctrl.kl_coef=0.001 \
    trainer.logger=['console'] \
    trainer.val_before_train=False \
    trainer.n_gpus_per_node=1 \
    trainer.nnodes=1 \
    actor_rollout_ref.rollout.name=hf \
    trainer.use_legacy_worker_impl=disable \
    trainer.total_training_steps=2
@@ -36,6 +36,7 @@ CUDA_KEYWORD_CHECK_WHITELIST = [
    "verl/trainer/ppo/ray_trainer.py",  # appear in default device_name
    "verl/utils/reward_score/sandbox_fusion/utils.py",  # appear in sandbox language type
    "verl/workers/reward_model/megatron/reward_model.py",  # appear in default device_name
    "verl/workers/engine/fsdp/engine_impl.py",
]

# directory or file path must contain keyword "nccl"
@@ -214,6 +214,7 @@ trainer:
  max_critic_ckpt_to_keep: null
  ray_wait_register_center_timeout: 300
  device: cuda
  use_legacy_worker_impl: auto
data:
  tokenizer: null
  use_shm: false
@@ -322,6 +322,10 @@ trainer:
  # Device to run training on (e.g., "cuda", "cpu")
  device: cuda

  # whether to use legacy worker implementation
  # mode: "auto", "enable", or "disable"
  use_legacy_worker_impl: auto

# configs related to ray initialization
ray_init:
@@ -127,7 +127,20 @@ class TaskRunner:
        if config.actor_rollout_ref.actor.strategy in {"fsdp", "fsdp2"}:
            assert config.critic.strategy in {"fsdp", "fsdp2"}
            from verl.single_controller.ray import RayWorkerGroup
            from verl.workers.fsdp_workers import ActorRolloutRefWorker, AsyncActorRolloutRefWorker, CriticWorker
            from verl.workers.fsdp_workers import ActorRolloutRefWorker, AsyncActorRolloutRefWorker

            use_legacy_worker_impl = config.trainer.get("use_legacy_worker_impl", "auto")
            if use_legacy_worker_impl in ["auto", "enable"]:
                # import warnings
                # warnings.warn(f"Legacy worker impl is going to be deprecated, will be removed in the future. \
                # Please set trainer.use_legacy_worker_impl = false to switch to the new worker implementation.")
                from verl.workers.fsdp_workers import CriticWorker
            elif use_legacy_worker_impl == "disable":
                from verl.workers.roles import CriticWorker

                print("Using new worker implementation")
            else:
                raise ValueError(f"Invalid use_legacy_worker_impl: {use_legacy_worker_impl}")

            actor_rollout_cls = (
                AsyncActorRolloutRefWorker
17	verl/workers/engine/__init__.py (new file)
@@ -0,0 +1,17 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .base import BaseEngine, EngineRegistry
from .fsdp import FSDPEngine

__all__ = ["BaseEngine", "EngineRegistry", "FSDPEngine"]
235	verl/workers/engine/base.py (new file)
@ -0,0 +1,235 @@
|
||||
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
The abstract base class defining the interface for model training engines.
|
||||
"""
|
||||
|
||||
from typing import Callable
|
||||
|
||||
import torch
|
||||
|
||||
from verl import DataProto
|
||||
|
||||
|
||||
class BaseEngine:
|
||||
"""
|
||||
Abstract base class defining the interface for model training engines.
|
||||
|
||||
Engine implementations must subclass BaseEngine and provide concrete behavior for all methods.
|
||||
"""
|
||||
|
||||
def __init__(self, config):
|
||||
"""
|
||||
Initialize the BaseEngine.
|
||||
|
||||
Args:
|
||||
config: Configuration object containing parameters for engine setup.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def init_model(self):
|
||||
"""
|
||||
Instantiate or load the model, optimizer, and learning rate scheduler.
|
||||
|
||||
Should prepare all components necessary for training or evaluation.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def train_mode(self):
|
||||
"""
|
||||
Context manager entry for switching the engine and model into training mode.
|
||||
|
||||
Usage:
|
||||
with engine.train_mode():
|
||||
# runs in training mode
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def eval_mode(self):
|
||||
"""
|
||||
Context manager entry for switching the engine and model into evaluation mode.
|
||||
|
||||
Usage:
|
||||
with engine.eval_mode():
|
||||
# runs in evaluation mode
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def infer_batch(
|
||||
self,
|
||||
data: DataProto,
|
||||
post_fn: Callable[[DataProto, torch.Tensor], tuple[torch.Tensor, dict[str, torch.Tensor]]],
|
||||
) -> dict[str, torch.Tensor]:
|
||||
"""
|
||||
Perform inference on a mini batch of data.
|
||||
|
||||
Args:
|
||||
data: The input data for inference, typically containing tensors and metadata.
|
||||
post_fn: A post-processing function that takes a micro-batch and predictions as input,
|
||||
and returns a tuple containing processed predictions and a dictionary of outputs.
|
||||
|
||||
Returns:
|
||||
dict[str, torch.Tensor]: A dictionary containing the predictions for the entire batch.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def train_batch(
|
||||
self,
|
||||
data: DataProto,
|
||||
loss_fn: Callable[[DataProto, torch.Tensor], tuple[torch.Tensor, dict[str, torch.Tensor]]],
|
||||
) -> dict[str, torch.Tensor]:
|
||||
"""
|
||||
Perform a training step on a mini-batch of data.
|
||||
|
||||
Args:
|
||||
data (DataProto): The input data for training, typically containing tensors and metadata.
|
||||
loss_fn (Callable): A function that computes the loss and metrics given a micro-batch and predictions.
|
||||
|
||||
Returns:
|
||||
dict[str, torch.Tensor]: A dictionary containing the aggregated training metrics for the mini-batch.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def optimizer_zero_grad(self):
|
||||
"""
|
||||
Zero out gradients of all parameters before starting a new backward pass.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def optimizer_step(self):
|
||||
"""
|
||||
Perform an optimization step to update model parameters based on accumulated gradients.
|
||||
|
||||
Returns:
|
||||
grad_norm (float): The norm of the gradients before clipping or update.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def lr_scheduler_step(self):
|
||||
"""
|
||||
Advance the learning rate scheduler by one step.
|
||||
|
||||
Returns:
|
||||
current_lr (float or list[float]): Updated learning rate(s).
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def shard_data(self, data):
|
||||
"""
|
||||
Shard or partition data for distributed training or parallel execution.
|
||||
|
||||
Args:
|
||||
data: Data structure to be sharded across devices/workers.
|
||||
|
||||
Returns:
|
||||
Sharded data in the same format as input.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def unshard_data(self, data):
|
||||
"""
|
||||
Reconstruct or gather sharded data back to a unified format.
|
||||
|
||||
Args:
|
||||
data: Sharded data structure to reconstruct.
|
||||
|
||||
Returns:
|
||||
Unsharded, combined data.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def to(self, device: str, model: bool = True, optimizer: bool = True):
|
||||
"""
|
||||
Move model parameters, optimizer states, or both to the specified device.
|
||||
|
||||
Args:
|
||||
device: Target device identifier.
|
||||
model: If True, move the model.
|
||||
optimizer: If True, move the optimizer states.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def save_checkpoint(self, local_path, hdfs_path=None, global_step=0, max_ckpt_to_keep=None):
|
||||
"""
|
||||
Save model, optimizer, and scheduler states to a checkpoint.
|
||||
|
||||
Args:
|
||||
local_path: Local filesystem path to save checkpoint.
|
||||
hdfs_path: Optional HDFS path to copy checkpoint.
|
||||
global_step: Integer training step number for naming.
|
||||
max_ckpt_to_keep: Maximum number of recent checkpoints to retain.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def load_checkpoint(self, local_path, hdfs_path=None, del_local_after_load=True):
|
||||
"""
|
||||
Load model, optimizer, and scheduler states from a checkpoint.
|
||||
|
||||
Args:
|
||||
local_path: Local filesystem path of the checkpoint.
|
||||
hdfs_path: Optional HDFS path where checkpoint is stored.
|
||||
del_local_after_load: Whether to delete local copy after loading.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class EngineRegistry:
|
||||
"""
|
||||
A registry for managing and instantiating different types of training engines.
|
||||
|
||||
This class uses a dictionary to store engine classes, mapping a string key to each class.
|
||||
It provides a decorator `register` to add new engines to the registry and a `new` method
|
||||
to create an instance of a registered engine.
|
||||
"""
|
||||
|
||||
_engines = {}
|
||||
|
||||
@classmethod
|
||||
def register(cls, key):
|
||||
"""
|
||||
A class method decorator that registers an engine class with a given key.
|
||||
|
||||
This allows for dynamic instantiation of engine classes by their registered key.
|
||||
|
||||
Args:
|
||||
key (str): The identifier to associate with the engine class.
|
||||
|
||||
Returns:
|
||||
A decorator function that takes an engine class and registers it.
|
||||
"""
|
||||
|
||||
def decorator(engine_class):
|
||||
assert issubclass(engine_class, BaseEngine)
|
||||
cls._engines[key] = engine_class
|
||||
return engine_class
|
||||
|
||||
return decorator
|
||||
|
||||
@classmethod
|
||||
def new(cls, key, *args, **kwargs):
|
||||
"""
|
||||
Function to create a new training engine instance based on the provided config.
|
||||
Args:
|
||||
key: A configuration object containing the engine key and other settings.
|
||||
*args: Variable length argument list.
|
||||
**kwargs: Arbitrary keyword arguments.
|
||||
Returns:
|
||||
engine: An instance of the training engine corresponding to the config.
|
||||
Raises:
|
||||
NotImplementedError: If the engine key in the config does not match any known engines.
|
||||
"""
|
||||
if key in cls._engines:
|
||||
return cls._engines[key](*args, **kwargs)
|
||||
else:
|
||||
raise NotImplementedError(f"Unknown engine: {key}")
|
16	verl/workers/engine/fsdp/__init__.py (new file)
@@ -0,0 +1,16 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .engine_impl import FSDPEngine

__all__ = ["FSDPEngine"]
727	verl/workers/engine/fsdp/engine_impl.py (new file)
@ -0,0 +1,727 @@
|
||||
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
The concrete Engine implementation using PyTorch FullyShardedDataParallel (FSDP)
|
||||
"""
|
||||
|
||||
import gc
|
||||
import itertools
|
||||
import logging
|
||||
import os
|
||||
import warnings
|
||||
from typing import Callable
|
||||
|
||||
import torch
|
||||
import torch.distributed
|
||||
from omegaconf import OmegaConf
|
||||
from peft import LoraConfig, TaskType, get_peft_model
|
||||
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
|
||||
|
||||
from verl import DataProto
|
||||
from verl.models.transformers.monkey_patch import apply_monkey_patch
|
||||
from verl.utils import hf_processor, hf_tokenizer
|
||||
from verl.utils.activation_offload import enable_activation_offloading
|
||||
from verl.utils.checkpoint.fsdp_checkpoint_manager import FSDPCheckpointManager
|
||||
from verl.utils.debug import log_gpu_memory_usage
|
||||
from verl.utils.device import (
|
||||
get_device_id,
|
||||
get_device_name,
|
||||
get_torch_device,
|
||||
is_cuda_available,
|
||||
is_npu_available,
|
||||
)
|
||||
from verl.utils.flops_counter import FlopsCounter
|
||||
from verl.utils.fs import copy_to_local
|
||||
from verl.utils.fsdp_utils import (
|
||||
CPUOffloadPolicy,
|
||||
FSDPModule,
|
||||
MixedPrecisionPolicy,
|
||||
apply_fsdp2,
|
||||
fsdp2_clip_grad_norm_,
|
||||
fsdp2_load_full_state_dict,
|
||||
get_fsdp_wrap_policy,
|
||||
get_init_weight_context_manager,
|
||||
init_fn,
|
||||
load_fsdp_model_to_gpu,
|
||||
load_fsdp_optimizer,
|
||||
offload_fsdp_model_to_cpu,
|
||||
offload_fsdp_optimizer,
|
||||
)
|
||||
from verl.utils.import_utils import import_external_libs
|
||||
from verl.utils.py_functional import append_to_dict, convert_to_regular_types
|
||||
from verl.utils.seqlen_balancing import get_reverse_idx, rearrange_micro_batches
|
||||
from verl.utils.ulysses import gather_outpus_and_unpad, ulysses_pad_and_slice_inputs
|
||||
from verl.workers.sharding_manager.fsdp_ulysses import FSDPUlyssesShardingManager
|
||||
|
||||
if is_cuda_available:
|
||||
from flash_attn.bert_padding import index_first_axis, pad_input, rearrange, unpad_input
|
||||
elif is_npu_available:
|
||||
from transformers.integrations.npu_flash_attention import index_first_axis, pad_input, rearrange, unpad_input
|
||||
|
||||
from ..base import BaseEngine, EngineRegistry
|
||||
from .utils import create_device_mesh, get_sharding_strategy
|
||||
|
||||
logger = logging.getLogger(__file__)
|
||||
logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))
|
||||
|
||||
|
||||
device_name = get_device_name()
|
||||
|
||||
|
||||
@EngineRegistry.register("fsdp")
|
||||
class FSDPEngine(BaseEngine):
|
||||
"""
|
||||
Concrete Engine implementation using PyTorch FullyShardedDataParallel (FSDP).
|
||||
|
||||
Supports model sharding, activation/optimizer offloading, LoRA, and sequence parallelism.
|
||||
"""
|
||||
|
||||
def __init__(self, config):
|
||||
"""
|
||||
Initialize the FSDPEngine.
|
||||
|
||||
Sets up distributed device meshes, LoRA, and offload policies based on config.
|
||||
|
||||
Args:
|
||||
config: Configuration object with FSDP and model settings.
|
||||
"""
|
||||
self.config = config
|
||||
self.rank = torch.distributed.get_rank()
|
||||
# build device mesh for Ulysses Sequence Parallel
|
||||
world_size = torch.distributed.get_world_size()
|
||||
from torch.distributed.device_mesh import init_device_mesh
|
||||
|
||||
fsdp_size = self.config.model.fsdp_config.fsdp_size
|
||||
self.device_mesh = create_device_mesh(world_size=world_size, fsdp_size=fsdp_size)
|
||||
self.use_remove_padding = config.model.get("use_remove_padding", False)
|
||||
|
||||
self.ulysses_device_mesh = None
|
||||
self.ulysses_sequence_parallel_size = self.config.get("ulysses_sequence_parallel_size", 1)
|
||||
dp = world_size // self.ulysses_sequence_parallel_size
|
||||
if self.ulysses_sequence_parallel_size > 1:
|
||||
self.ulysses_device_mesh = init_device_mesh(
|
||||
device_name, mesh_shape=(dp, self.ulysses_sequence_parallel_size), mesh_dim_names=["dp", "sp"]
|
||||
)
|
||||
|
||||
self.ulysses_sharding_manager = FSDPUlyssesShardingManager(self.ulysses_device_mesh)
|
||||
|
||||
# set FSDP offload params
|
||||
self._is_offload_param = self.config.model.fsdp_config.param_offload
|
||||
self._is_offload_optimizer = self.config.model.fsdp_config.optimizer_offload
|
||||
|
||||
# normalize config
|
||||
self.config.ppo_mini_batch_size *= self.config.rollout_n
|
||||
self.config.ppo_mini_batch_size //= torch.distributed.get_world_size() // self.ulysses_sequence_parallel_size
|
||||
if self.config.ppo_micro_batch_size is not None:
|
||||
self.config.ppo_micro_batch_size //= (
|
||||
torch.distributed.get_world_size() // self.ulysses_sequence_parallel_size
|
||||
)
|
||||
self.config.forward_micro_batch_size //= (
|
||||
torch.distributed.get_world_size() // self.ulysses_sequence_parallel_size
|
||||
)
|
||||
self.config.ppo_micro_batch_size_per_gpu = self.config.ppo_micro_batch_size
|
||||
self.config.forward_micro_batch_size_per_gpu = self.config.forward_micro_batch_size
|
||||
|
||||
if self.config.ppo_micro_batch_size_per_gpu is not None:
|
||||
assert self.config.ppo_mini_batch_size % self.config.ppo_micro_batch_size_per_gpu == 0, (
|
||||
f"normalized ppo_mini_batch_size {self.config.ppo_mini_batch_size} should be divisible by "
|
||||
f"ppo_micro_batch_size_per_gpu {self.config.ppo_micro_batch_size_per_gpu}"
|
||||
)
|
||||
assert self.config.ppo_mini_batch_size // self.config.ppo_micro_batch_size_per_gpu > 0, (
|
||||
f"normalized ppo_mini_batch_size {self.config.ppo_mini_batch_size} should be larger than "
|
||||
f"ppo_micro_batch_size_per_gpu {self.config.ppo_micro_batch_size_per_gpu}"
|
||||
)
|
||||
self._is_lora = self.config.model.get("lora_rank", 0) > 0
|
||||
|
||||
def init_model(self):
|
||||
"""
|
||||
Build the model, optimizer, and learning rate scheduler under FSDP.
|
||||
|
||||
Applies device, dtype, and precision configurations, including mixed precision.
|
||||
Sets up checkpoint manager and FLOPs counter.
|
||||
"""
|
||||
# This is used to import external_lib into the huggingface systems
|
||||
import_external_libs(self.config.model.get("external_lib", None))
|
||||
|
||||
self.module, self.optimizer, self.lr_scheduler = self._build_model_optimizer(self.config)
|
||||
|
||||
if self._is_offload_param:
|
||||
offload_fsdp_model_to_cpu(self.module)
|
||||
log_gpu_memory_usage("After offload model during init", logger=logger)
|
||||
if self._is_offload_optimizer:
|
||||
offload_fsdp_optimizer(optimizer=self.optimizer)
|
||||
log_gpu_memory_usage("After offload optimizer during init", logger=logger)
|
||||
|
||||
self.flops_counter = FlopsCounter(self.model_config)
|
||||
self.checkpoint_manager = FSDPCheckpointManager(
|
||||
model=self.module,
|
||||
optimizer=self.optimizer,
|
||||
lr_scheduler=self.lr_scheduler,
|
||||
processing_class=self.processor if self.processor is not None else self.tokenizer,
|
||||
checkpoint_contents=self.config.checkpoint,
|
||||
)
|
||||
|
||||
def _build_model_optimizer(self, config):
|
||||
# the following line is necessary
|
||||
from torch import optim
|
||||
from torch.distributed.fsdp import MixedPrecision
|
||||
|
||||
from verl.utils.model import load_valuehead_model, print_model_size
|
||||
from verl.utils.torch_dtypes import PrecisionType
|
||||
|
||||
use_shm = config.model.get("use_shm", False)
|
||||
local_path = copy_to_local(config.model.path, use_shm=use_shm)
|
||||
# note that the tokenizer between actor and critic may be different. So override tokenizer info with actor info
|
||||
# using random initialized model from any architecture. May not be the same as Actor.
|
||||
|
||||
tokenizer_path = copy_to_local(config.model.tokenizer_path, use_shm=use_shm)
|
||||
self.tokenizer = hf_tokenizer(tokenizer_path, trust_remote_code=config.model.get("trust_remote_code", False))
|
||||
self.processor = hf_processor(tokenizer_path, trust_remote_code=config.model.get("trust_remote_code", False))
|
||||
|
||||
if self.config.model.get("custom_chat_template", None) is not None:
|
||||
if self.processor is not None:
|
||||
self.processor.chat_template = self.config.model.custom_chat_template
|
||||
else:
|
||||
self.tokenizer.chat_template = self.config.model.custom_chat_template
|
||||
|
||||
override_config = OmegaConf.to_container(self.config.model.get("override_config", OmegaConf.create()))
|
||||
override_config_kwargs = {
|
||||
"bos_token_id": self.tokenizer.bos_token_id,
|
||||
"eos_token_id": self.tokenizer.eos_token_id,
|
||||
"pad_token_id": self.tokenizer.pad_token_id,
|
||||
}
|
||||
override_config_kwargs.update(override_config)
|
||||
if self.rank == 0:
|
||||
print(f"Engine overriding config {override_config_kwargs}")
|
||||
|
||||
torch_dtype = self.config.model.fsdp_config.get("model_dtype", "fp32")
|
||||
torch_dtype = PrecisionType.to_dtype(torch_dtype)
|
||||
|
||||
from transformers import AutoConfig
|
||||
|
||||
model_config = AutoConfig.from_pretrained(
|
||||
local_path,
|
||||
attn_implementation="flash_attention_2",
|
||||
trust_remote_code=config.model.get("trust_remote_code", False),
|
||||
)
|
||||
model_config.num_labels = 1
|
||||
# patch for kimi-vl
|
||||
if getattr(model_config, "model_type", None) == "kimi_vl":
|
||||
model_config.text_config.topk_method = "greedy"
|
||||
|
||||
init_context = get_init_weight_context_manager(
|
||||
use_meta_tensor=not model_config.tie_word_embeddings, mesh=self.device_mesh
|
||||
)
|
||||
|
||||
with init_context(), warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore")
|
||||
model_config.classifier_dropout = 0.0
|
||||
model_config.hidden_dropout = "0"
|
||||
model_config.summary_dropout_prob = 0.0
|
||||
|
||||
module = load_valuehead_model(
|
||||
local_path,
|
||||
torch_dtype,
|
||||
model_config,
|
||||
config.model.get("trust_remote_code", False),
|
||||
)
|
||||
|
||||
apply_monkey_patch(
|
||||
model=module,
|
||||
use_remove_padding=self.use_remove_padding,
|
||||
ulysses_sp_size=self.ulysses_sequence_parallel_size,
|
||||
)
|
||||
|
||||
# some parameters may not in torch_dtype
|
||||
module.to(torch_dtype)
|
||||
|
||||
if config.model.get("enable_gradient_checkpointing", False):
|
||||
module.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
|
||||
|
||||
if self._is_lora:
|
||||
print("Applying LoRA to the module")
|
||||
module.enable_input_require_grads()
|
||||
# Convert config to regular Python types before creating PEFT model
|
||||
lora_config = {
|
||||
"task_type": TaskType.CAUSAL_LM,
|
||||
"r": self.config.model.lora_rank,
|
||||
"lora_alpha": self.config.model.lora_alpha,
|
||||
"target_modules": convert_to_regular_types(self.config.model.target_modules),
|
||||
"bias": "none",
|
||||
}
|
||||
module = get_peft_model(module, LoraConfig(**lora_config))
|
||||
|
||||
if self.rank == 0:
|
||||
print_model_size(module)
|
||||
|
||||
self.model_config = model_config
|
||||
|
||||
fsdp_config = self.config.model.fsdp_config
|
||||
mixed_precision_config = fsdp_config.get("mixed_precision", None)
|
||||
if mixed_precision_config is not None:
|
||||
param_dtype = PrecisionType.to_dtype(mixed_precision_config.get("param_dtype", "bf16"))
|
||||
reduce_dtype = PrecisionType.to_dtype(mixed_precision_config.get("reduce_dtype", "fp32"))
|
||||
buffer_dtype = PrecisionType.to_dtype(mixed_precision_config.get("buffer_dtype", "fp32"))
|
||||
else:
|
||||
param_dtype = torch.bfloat16
|
||||
reduce_dtype = torch.float32
|
||||
buffer_dtype = torch.float32
|
||||
|
||||
mixed_precision = MixedPrecision(param_dtype=param_dtype, reduce_dtype=reduce_dtype, buffer_dtype=buffer_dtype)
|
||||
|
||||
auto_wrap_policy = get_fsdp_wrap_policy(
|
||||
module=module,
|
||||
config=self.config.model.fsdp_config.wrap_policy,
|
||||
is_lora=self.config.model.get("lora_rank", 0) > 0,
|
||||
)
|
||||
|
||||
log_gpu_memory_usage("Before FSDP", logger=None)
|
||||
|
||||
fsdp_mesh = self.device_mesh
|
||||
sharding_strategy = get_sharding_strategy(fsdp_mesh)
|
||||
|
||||
# Note: We force turn off CPUOffload because it causes incorrect results when using grad accumulation
|
||||
if config.strategy == "fsdp":
|
||||
module = FSDP(
|
||||
module,
|
||||
param_init_fn=init_fn,
|
||||
use_orig_params=False,
|
||||
auto_wrap_policy=auto_wrap_policy,
|
||||
device_id=get_device_id(),
|
||||
sharding_strategy=sharding_strategy,
|
||||
mixed_precision=mixed_precision,
|
||||
sync_module_states=True,
|
||||
forward_prefetch=self.config.model.fsdp_config.forward_prefetch,
|
||||
device_mesh=self.device_mesh,
|
||||
cpu_offload=None,
|
||||
)
|
||||
elif config.strategy == "fsdp2":
|
||||
assert CPUOffloadPolicy is not None, "PyTorch version >= 2.4 is required for using fully_shard API (FSDP2)"
|
||||
mp_policy = MixedPrecisionPolicy(
|
||||
param_dtype=param_dtype, reduce_dtype=reduce_dtype, cast_forward_inputs=True
|
||||
)
|
||||
offload_policy = None
|
||||
if fsdp_config.offload_policy:
|
||||
self._is_offload_param = False
|
||||
self._is_offload_optimizer = False
|
||||
offload_policy = CPUOffloadPolicy(pin_memory=True)
|
||||
|
||||
fsdp_kwargs = {
|
||||
"mesh": fsdp_mesh,
|
||||
"mp_policy": mp_policy,
|
||||
"offload_policy": offload_policy,
|
||||
"reshard_after_forward": fsdp_config.reshard_after_forward,
|
||||
}
|
||||
full_state = module.state_dict()
|
||||
apply_fsdp2(module, fsdp_kwargs, fsdp_config)
|
||||
fsdp2_load_full_state_dict(module, full_state, fsdp_mesh, offload_policy)
|
||||
else:
|
||||
raise NotImplementedError(f"Unknown strategy {config.strategy}")
|
||||
|
||||
if config.model.get("enable_activation_offload", False):
|
||||
enable_gradient_checkpointing = config.model.get("enable_gradient_checkpointing", False)
|
||||
enable_activation_offloading(module, config.strategy, enable_gradient_checkpointing)
|
||||
|
||||
log_gpu_memory_usage("After FSDP", logger=None)
|
||||
|
||||
optimizer = optim.AdamW(
|
||||
module.parameters(),
|
||||
lr=config.optim.lr,
|
||||
betas=config.optim.get("betas", (0.9, 0.999)),
|
||||
weight_decay=config.optim.get("weight_decay", 1e-2),
|
||||
)
|
||||
|
||||
total_steps = config.optim.get("total_training_steps", 0)
|
||||
num_warmup_steps = int(config.optim.get("lr_warmup_steps", -1))
|
||||
warmup_style = config.optim.get("warmup_style", "constant")
|
||||
if num_warmup_steps < 0:
|
||||
num_warmup_steps_ratio = config.optim.get("lr_warmup_steps_ratio", 0.0)
|
||||
num_warmup_steps = int(num_warmup_steps_ratio * total_steps)
|
||||
|
||||
if self.rank == 0:
|
||||
print(f"Total steps: {total_steps}, num_warmup_steps: {num_warmup_steps}")
|
||||
|
||||
from verl.utils.torch_functional import get_constant_schedule_with_warmup, get_cosine_schedule_with_warmup
|
||||
|
||||
if warmup_style == "constant":
|
||||
lr_scheduler = get_constant_schedule_with_warmup(optimizer=optimizer, num_warmup_steps=num_warmup_steps)
|
||||
elif warmup_style == "cosine":
|
||||
lr_scheduler = get_cosine_schedule_with_warmup(
|
||||
optimizer=optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=total_steps
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(f"Warmup style {warmup_style} is not supported")
|
||||
|
||||
return module, optimizer, lr_scheduler
|
||||
|
||||
def train_mode(self):
|
||||
"""
|
||||
Return a context manager that switches to training mode with FSDP-specific handling.
|
||||
|
||||
Includes parameter and optimizer offload entry/exit.
|
||||
"""
|
||||
return EngineTrainModeCtx(self)
|
||||
|
||||
def eval_mode(self):
|
||||
"""
|
||||
Return a context manager that switches to evaluation mode with FSDP-specific handling.
|
||||
|
||||
Includes activation offload entry/exit.
|
||||
"""
|
||||
return EngineEvalModeCtx(self)
|
||||
|
||||
def shard_data(self, data):
|
||||
"""
|
||||
Preprocess data into sharded format via UlyssesShardingManager.
|
||||
"""
|
||||
return self.ulysses_sharding_manager.preprocess_data(data)
|
||||
|
||||
def unshard_data(self, data):
|
||||
"""
|
||||
Postprocess data from sharded format back to full format.
|
||||
"""
|
||||
return self.ulysses_sharding_manager.postprocess_data(data)
|
||||
|
||||
def get_default_ctx(self):
|
||||
use_value_head_model = hasattr(self.module, "v_head")
|
||||
ctx = {
|
||||
"use_value_head_model": use_value_head_model,
|
||||
"ulysses_sequence_parallel_size": self.ulysses_sequence_parallel_size,
|
||||
}
|
||||
return ctx
|
||||
|
||||
def _forward_micro_batch(self, micro_batch):
|
||||
multi_modal_inputs = {}
|
||||
if "multi_modal_inputs" in micro_batch.keys():
|
||||
for key in micro_batch["multi_modal_inputs"][0].keys():
|
||||
multi_modal_inputs[key] = torch.cat(
|
||||
[inputs[key] for inputs in micro_batch["multi_modal_inputs"]], dim=0
|
||||
)
|
||||
|
||||
with torch.autocast(device_type=device_name, dtype=torch.bfloat16):
|
||||
input_ids = micro_batch["input_ids"]
|
||||
batch, seqlen = input_ids.shape
|
||||
attention_mask = micro_batch["attention_mask"]
|
||||
position_ids = micro_batch["position_ids"]
|
||||
if position_ids.dim() == 3: # qwen2vl mrope
|
||||
position_ids = position_ids.transpose(0, 1)
|
||||
|
||||
if self.use_remove_padding:
|
||||
input_ids_rmpad, indices, *_ = unpad_input(
|
||||
input_ids.unsqueeze(-1), attention_mask
|
||||
) # input_ids_rmpad (total_nnz, ...)
|
||||
input_ids_rmpad = input_ids_rmpad.transpose(0, 1) # (1, total_nnz)
|
||||
|
||||
# unpad the position_ids to align the rotary
|
||||
if position_ids.dim() == 3:
|
||||
position_ids_rmpad = (
|
||||
index_first_axis(rearrange(position_ids, "c b s ... -> (b s) c ..."), indices)
|
||||
.transpose(0, 1)
|
||||
.unsqueeze(1)
|
||||
) # (3, bsz, seqlen) -> (3, 1, bsz * seqlen)
|
||||
else:
|
||||
position_ids_rmpad = index_first_axis(
|
||||
rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."), indices
|
||||
).transpose(0, 1)
|
||||
|
||||
# pad and slice the inputs if sp > 1
|
||||
if self.ulysses_sequence_parallel_size > 1:
|
||||
input_ids_rmpad, position_ids_rmpad, pad_size = ulysses_pad_and_slice_inputs(
|
||||
input_ids_rmpad, position_ids_rmpad, sp_size=self.ulysses_sequence_parallel_size
|
||||
)
|
||||
|
||||
# only pass input_ids and position_ids to enable flash_attn_varlen
|
||||
preds = self.module(
|
||||
input_ids=input_ids_rmpad,
|
||||
attention_mask=None,
|
||||
position_ids=position_ids_rmpad,
|
||||
**multi_modal_inputs,
|
||||
use_cache=False,
|
||||
) # prevent model thinks we are generating
|
||||
|
||||
if hasattr(self.module, "v_head"):
|
||||
# For trl.AutoModelForCausalLMWithValueHead
|
||||
preds_rmpad = preds[2].squeeze(0).unsqueeze(-1)
|
||||
else:
|
||||
preds_rmpad = preds.logits
|
||||
preds_rmpad = preds_rmpad.squeeze(0) # (total_nnz)
|
||||
|
||||
# gather output if sp > 1
|
||||
if self.ulysses_sequence_parallel_size > 1:
|
||||
preds_rmpad = gather_outpus_and_unpad(preds_rmpad, gather_dim=0, unpad_dim=0, padding_size=pad_size)
|
||||
|
||||
# pad it back
|
||||
preds = pad_input(preds_rmpad, indices=indices, batch=batch, seqlen=seqlen).squeeze(-1)
|
||||
else:
|
||||
preds = self.module(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
**multi_modal_inputs,
|
||||
use_cache=False,
|
||||
) # prevent model thinks we are generating
|
||||
if hasattr(self.module, "v_head"):
|
||||
# For trl.AutoModelForCausalLMWithValueHead
|
||||
preds = preds[2]
|
||||
else:
|
||||
preds = preds.logits
|
||||
|
||||
return preds
|
||||
|
||||
def infer_batch(
|
||||
self,
|
||||
data: DataProto,
|
||||
post_fn: Callable[[DataProto, torch.Tensor], tuple[torch.Tensor, dict[str, torch.Tensor]]],
|
||||
) -> dict[str, torch.Tensor]:
|
||||
"""
|
||||
Perform inference on a mini batch of data.
|
||||
|
||||
Args:
|
||||
data: The input data for inference, typically containing tensors and metadata.
|
||||
post_fn: A post-processing function that takes a micro-batch and predictions as input,
|
||||
and returns a tuple containing processed predictions and a dictionary of outputs.
|
||||
|
||||
Returns:
|
||||
dict[str, torch.Tensor]: A dictionary containing the predictions for the entire batch.
|
||||
"""
|
||||
assert self.mode == "eval"
|
||||
micro_batch_size = data.meta_info["micro_batch_size"]
|
||||
select_keys = ["responses", "input_ids", "attention_mask", "position_ids"]
|
||||
batch = data.select(batch_keys=select_keys).batch
|
||||
use_dynamic_bsz = data.meta_info["use_dynamic_bsz"]
|
||||
has_multi_modal_inputs = "multi_modal_inputs" in data.non_tensor_batch.keys()
|
||||
|
||||
if has_multi_modal_inputs:
|
||||
num_micro_batches = data.batch.batch_size[0] // micro_batch_size
|
||||
non_tensor_select_keys = ["multi_modal_inputs"]
|
||||
micro_batches = data.select(select_keys, non_tensor_select_keys).chunk(num_micro_batches)
|
||||
elif use_dynamic_bsz:
|
||||
# split using dynamic bsz
|
||||
max_token_len = data.meta_info["max_token_len"] * self.ulysses_sequence_parallel_size
|
||||
micro_batches, indices = rearrange_micro_batches(batch=batch, max_token_len=max_token_len)
|
||||
else:
|
||||
micro_batches = batch.split(micro_batch_size)
|
||||
|
||||
preds_list = {}
|
||||
for micro_batch in micro_batches:
|
||||
if isinstance(micro_batch, DataProto):
|
||||
micro_batch = {**micro_batch.batch, **micro_batch.non_tensor_batch}
|
||||
|
||||
with torch.no_grad():
|
||||
# micro_batch_preds would be a dict[str, torch.Tensor]
|
||||
preds = self._forward_micro_batch(micro_batch)
|
||||
_, outputs = post_fn(micro_batch, preds)
|
||||
assert isinstance(outputs, dict)
|
||||
|
||||
# append micro batch preds to dict[str, List[torch.Tensor]]
|
||||
append_to_dict(preds_list, outputs)
|
||||
|
||||
# reorganize mini batch preds from
|
||||
# dict[str, List[torch.Tensor]] to dict[str, torch.Tensor]
|
||||
mini_batch_preds = {}
|
||||
for key, t_list in preds_list.items():
|
||||
t_concat = torch.concat(t_list, dim=0)
|
||||
|
||||
if use_dynamic_bsz:
|
||||
indices = list(itertools.chain.from_iterable(indices))
|
||||
assert len(indices) == t_concat.size(0), f"{len(indices)} vs. {t_concat.size()}"
|
||||
revert_indices = torch.tensor(get_reverse_idx(indices), dtype=torch.long)
|
||||
t_concat = t_concat[revert_indices]
|
||||
|
||||
mini_batch_preds[key] = t_concat
|
||||
|
||||
return mini_batch_preds
|
||||
|
||||
def train_batch(
|
||||
self,
|
||||
data: DataProto,
|
||||
loss_fn: Callable[[DataProto, torch.Tensor], tuple[torch.Tensor, dict[str, torch.Tensor]]],
|
||||
) -> dict[str, torch.Tensor]:
|
||||
"""
|
||||
Perform a training step on a mini-batch of data.
|
||||
|
||||
Args:
|
||||
data (DataProto): The input data for training, typically containing tensors and metadata.
|
||||
loss_fn (Callable): A function that computes the loss and metrics given a micro-batch and predictions.
|
||||
|
||||
Returns:
|
||||
dict[str, torch.Tensor]: A dictionary containing the aggregated training metrics for the mini-batch.
|
||||
"""
|
||||
assert self.mode == "train"
|
||||
# split batch into micro_batches
|
||||
mini_batch = data
|
||||
select_keys = ["input_ids", "responses", "response_mask", "attention_mask", "position_ids"]
|
||||
if "multi_modal_inputs" in mini_batch:
|
||||
non_tensor_select_keys = ["multi_modal_inputs"]
|
||||
num_micro_batches = mini_batch.batch.batch_size[0] // self.config.ppo_micro_batch_size_per_gpu
|
||||
micro_batches = mini_batch.select(select_keys, non_tensor_select_keys).chunk(num_micro_batches)
|
||||
self.config.ppo_mini_batch_size // self.config.ppo_micro_batch_size_per_gpu
|
||||
elif self.config.use_dynamic_bsz:
|
||||
max_token_len = self.config.ppo_max_token_len_per_gpu * self.ulysses_sequence_parallel_size
|
||||
micro_batches, _ = rearrange_micro_batches(batch=mini_batch, max_token_len=max_token_len)
|
||||
else:
|
||||
micro_batches = mini_batch.split(self.config.ppo_micro_batch_size_per_gpu)
|
||||
self.config.ppo_mini_batch_size // self.config.ppo_micro_batch_size_per_gpu
|
||||
|
||||
mini_batch_metrics = {}
|
||||
for micro_batch in micro_batches:
|
||||
# Support all devices
|
||||
if isinstance(micro_batch, DataProto):
|
||||
micro_batch = {**micro_batch.batch.to(get_device_id()), **micro_batch.non_tensor_batch}
|
||||
else:
|
||||
micro_batch = micro_batch.to(get_device_id()) # critic device is cpu when using offload
|
||||
|
||||
preds = self._forward_micro_batch(micro_batch)
|
||||
loss, micro_batch_metrics = loss_fn(micro_batch, preds)
|
||||
append_to_dict(mini_batch_metrics, micro_batch_metrics)
|
||||
loss.backward()
|
||||
|
||||
return mini_batch_metrics
|
||||
|
||||
def optimizer_zero_grad(self):
|
||||
"""
|
||||
Zero gradients and enforce FSDP grad-clipping logic.
|
||||
"""
|
||||
self.optimizer.zero_grad()
|
||||
|
||||
def optimizer_step(self):
|
||||
"""
|
||||
Clip gradients, skip update if non-finite, and step optimizer.
|
||||
|
||||
Returns:
|
||||
grad_norm (float): Norm of gradients before clipping.
|
||||
"""
|
||||
assert self.config.grad_clip is not None
|
||||
|
||||
if isinstance(self.module, FSDP):
|
||||
grad_norm = self.module.clip_grad_norm_(self.config.grad_clip)
|
||||
elif isinstance(self.module, FSDPModule):
|
||||
grad_norm = fsdp2_clip_grad_norm_(self.module.parameters(), max_norm=self.config.grad_clip)
|
||||
else:
|
||||
grad_norm = torch.nn.utils.clip_grad_norm_(self.module.parameters(), max_norm=self.config.grad_clip)
|
||||
|
||||
# if grad_norm is not finite, skip the update
|
||||
if not torch.isfinite(grad_norm):
|
||||
print(f"WARN: grad_norm is not finite: {grad_norm}")
|
||||
self.optimizer.zero_grad()
|
||||
else:
|
||||
self.optimizer.step()
|
||||
return grad_norm
|
||||
|
||||
def lr_scheduler_step(self):
|
||||
"""
|
||||
Advance FSDP scheduler and return updated learning rate.
|
||||
"""
|
||||
self.lr_scheduler.step()
|
||||
lr = self.lr_scheduler.get_last_lr()
|
||||
return lr
|
||||
|
||||
def to(self, device: str, model: bool = True, optimizer: bool = True):
|
||||
"""
|
||||
Move FSDP model and/or optimizer to CPU or GPU with offload support.
|
||||
"""
|
||||
assert device in ("cuda", "cpu")
|
||||
if device == "cuda":
|
||||
if not self.config.model.fsdp_config.param_offload:
|
||||
if model:
|
||||
load_fsdp_model_to_gpu(self.model_module)
|
||||
if optimizer and self.optimizer is not None:
|
||||
load_fsdp_optimizer(self.optimizer, device)
|
||||
gc.collect()
|
||||
elif device == "cpu":
|
||||
if not self.config.model.fsdp_config.param_offload:
|
||||
if model:
|
||||
offload_fsdp_model_to_cpu(self.model_module)
|
||||
if optimizer and self.optimizer is not None:
|
||||
offload_fsdp_optimizer(self.optimizer)
|
||||
else:
|
||||
raise ValueError(f"Invalid device type: {device}")
|
||||
|
||||
def save_checkpoint(self, local_path, hdfs_path=None, global_step=0, max_ckpt_to_keep=None):
|
||||
"""
|
||||
Save FSDP checkpoint, handling parameter offload as needed.
|
||||
"""
|
||||
if self._is_offload_param:
|
||||
load_fsdp_model_to_gpu(self.module)
|
||||
|
||||
self.checkpoint_manager.save_checkpoint(
|
||||
local_path=local_path, hdfs_path=hdfs_path, global_step=global_step, max_ckpt_to_keep=max_ckpt_to_keep
|
||||
)
|
||||
|
||||
torch.distributed.barrier()
|
||||
if self._is_offload_param:
|
||||
offload_fsdp_model_to_cpu(self.module)
|
||||
|
||||
def load_checkpoint(self, local_path, hdfs_path=None, del_local_after_load=True):
|
||||
"""
|
||||
Load FSDP checkpoint, restoring parameters and optimizer state.
|
||||
"""
|
||||
import torch
|
||||
|
||||
if self._is_offload_param:
|
||||
load_fsdp_model_to_gpu(self.module)
|
||||
|
||||
self.checkpoint_manager.load_checkpoint(
|
||||
local_path=local_path, hdfs_path=hdfs_path, del_local_after_load=del_local_after_load
|
||||
)
|
||||
|
||||
torch.distributed.barrier()
|
||||
if self._is_offload_param:
|
||||
offload_fsdp_model_to_cpu(self.module)
|
||||
|
||||
if self._is_offload_optimizer:
|
||||
offload_fsdp_optimizer(self.optimizer)
|
||||
|
||||
|
||||
class EngineEvalModeCtx:
|
||||
def __init__(self, engine):
|
||||
self.engine = engine
|
||||
|
||||
def __enter__(self):
|
||||
self.engine.mode = "eval"
|
||||
if self.engine._is_offload_param:
|
||||
load_fsdp_model_to_gpu(self.engine.module)
|
||||
|
||||
self.engine.ulysses_sharding_manager.__enter__()
|
||||
self.engine.module.eval()
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback):
|
||||
self.engine.ulysses_sharding_manager.__exit__(exc_type, exc_value, traceback)
|
||||
if self.engine._is_offload_param:
|
||||
offload_fsdp_model_to_cpu(self.engine.module)
|
||||
self.engine.mode = None
|
||||
|
||||
|
||||
class EngineTrainModeCtx:
|
||||
def __init__(self, engine):
|
||||
self.engine = engine
|
||||
|
||||
def __enter__(self):
|
||||
self.engine.mode = "train"
|
||||
if self.engine._is_offload_param:
|
||||
load_fsdp_model_to_gpu(self.engine.module)
|
||||
if self.engine._is_offload_optimizer:
|
||||
load_fsdp_optimizer(optimizer=self.engine.optimizer, device_id=get_torch_device().current_device())
|
||||
|
||||
self.engine.ulysses_sharding_manager.__enter__()
|
||||
self.engine.module.train()
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback):
|
||||
self.engine.ulysses_sharding_manager.__exit__(exc_type, exc_value, traceback)
|
||||
|
||||
if self.engine._is_offload_param:
|
||||
offload_fsdp_model_to_cpu(self.engine.module)
|
||||
if self.engine._is_offload_optimizer:
|
||||
offload_fsdp_optimizer(optimizer=self.optimizer)
|
||||
self.engine.mode = None
|
61	verl/workers/engine/fsdp/utils.py (new file)
@@ -0,0 +1,61 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from torch.distributed.device_mesh import init_device_mesh

from verl.utils.device import get_device_name


def create_device_mesh(world_size, fsdp_size):
    """
    Create a device mesh for distributed training based on the world size and FSDP size.

    Args:
        world_size (int): Total number of processes in the distributed training setup.
        fsdp_size (int): Size of the Fully Sharded Data Parallel (FSDP) group.

    Returns:
        torch.distributed.device_mesh.DeviceMesh: The initialized device mesh.
    """
    device_name = get_device_name()
    if fsdp_size < 0 or fsdp_size >= world_size:
        device_mesh = init_device_mesh(device_name, mesh_shape=(world_size,), mesh_dim_names=["fsdp"])
    else:
        device_mesh = init_device_mesh(
            device_name, mesh_shape=(world_size // fsdp_size, fsdp_size), mesh_dim_names=["ddp", "fsdp"]
        )
    return device_mesh


def get_sharding_strategy(device_mesh):
    """
    Determine the appropriate sharding strategy based on the number of dimensions of the device mesh.

    Args:
        device_mesh (torch.distributed.device_mesh.DeviceMesh): The device mesh used for distributed training.

    Returns:
        torch.distributed.fsdp.ShardingStrategy: The sharding strategy to be used with FSDP.

    Raises:
        NotImplementedError: If the number of dimensions of the device mesh is neither 1 nor 2.
    """
    from torch.distributed.fsdp import ShardingStrategy

    if device_mesh.ndim == 1:
        sharding_strategy = ShardingStrategy.FULL_SHARD
    elif device_mesh.ndim == 2:
        sharding_strategy = ShardingStrategy.HYBRID_SHARD
    else:
        raise NotImplementedError(f"Get device mesh ndim={device_mesh.ndim}, but only support 1 or 2")
    return sharding_strategy
13	verl/workers/engine/megatron/__init__.py (new file)
@@ -0,0 +1,13 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
166	verl/workers/engine/megatron/engine_impl.py (new file)
@ -0,0 +1,166 @@
|
||||
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Callable
|
||||
|
||||
import torch
|
||||
|
||||
from verl import DataProto
|
||||
|
||||
from ..base import BaseEngine, EngineRegistry
|
||||
|
||||
|
||||
@EngineRegistry.register("megatron")
|
||||
class MegatronEngine(BaseEngine):
|
||||
def __init__(self, config):
|
||||
raise NotImplementedError
|
||||
|
||||
def init_model(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def train_mode(self):
|
||||
"""
|
||||
Context manager entry for switching the engine and model into training mode.
|
||||
|
||||
Usage:
|
||||
with engine.train_mode():
|
||||
# runs in training mode
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def eval_mode(self):
|
||||
"""
|
||||
Context manager entry for switching the engine and model into evaluation mode.
|
||||
|
||||
Usage:
|
||||
with engine.eval_mode():
|
||||
# runs in evaluation mode
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def infer_batch(
|
||||
self,
|
||||
data: DataProto,
|
||||
post_fn: Callable[[DataProto, torch.Tensor], tuple[torch.Tensor, dict[str, torch.Tensor]]],
|
||||
) -> dict[str, torch.Tensor]:
|
||||
"""
|
||||
Perform inference on a mini batch of data.
|
||||
|
||||
Args:
|
||||
data: The input data for inference, typically containing tensors and metadata.
|
||||
post_fn: A post-processing function that takes a micro-batch and predictions as input,
|
||||
and returns a tuple containing processed predictions and a dictionary of outputs.
|
||||
|
||||
Returns:
|
||||
dict[str, torch.Tensor]: A dictionary containing the predictions for the entire batch.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def train_batch(
|
||||
self,
|
||||
data: DataProto,
|
||||
loss_fn: Callable[[DataProto, torch.Tensor], tuple[torch.Tensor, dict[str, torch.Tensor]]],
|
||||
) -> dict[str, torch.Tensor]:
|
||||
"""
|
||||
Perform a training step on a mini-batch of data.
|
||||
|
||||
Args:
|
||||
data (DataProto): The input data for training, typically containing tensors and metadata.
|
||||
loss_fn (Callable): A function that computes the loss and metrics given a micro-batch and predictions.
|
||||
|
||||
Returns:
|
||||
dict[str, torch.Tensor]: A dictionary containing the aggregated training metrics for the mini-batch.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def optimizer_zero_grad(self):
|
||||
"""
|
||||
Zero out gradients of all parameters before starting a new backward pass.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def optimizer_step(self):
|
||||
"""
|
||||
Perform an optimization step to update model parameters based on accumulated gradients.
|
||||
|
||||
Returns:
|
||||
grad_norm (float): The norm of the gradients before clipping or update.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def lr_scheduler_step(self):
|
||||
"""
|
||||
Advance the learning rate scheduler by one step.
|
||||
|
||||
Returns:
|
||||
current_lr (float or list[float]): Updated learning rate(s).
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def shard_data(self, data):
|
||||
"""
|
||||
Shard or partition data for distributed training or parallel execution.
|
||||
|
||||
Args:
|
||||
data: Data structure to be sharded across devices/workers.
|
||||
|
||||
Returns:
|
||||
Sharded data in the same format as input.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def unshard_data(self, data):
|
||||
"""
|
||||
Reconstruct or gather sharded data back to a unified format.
|
||||
|
||||
Args:
|
||||
data: Sharded data structure to reconstruct.
|
||||
|
||||
Returns:
|
||||
Unsharded, combined data.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def to(self, device: str, model: bool = True, optimizer: bool = True):
|
||||
"""
|
||||
Move model parameters, optimizer states, or both to the specified device.
|
||||
|
||||
Args:
|
||||
device: Target device identifier.
|
||||
model: If True, move the model.
|
||||
optimizer: If True, move the optimizer states.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def save_checkpoint(self, local_path, hdfs_path=None, global_step=0, max_ckpt_to_keep=None):
|
||||
"""
|
||||
Save model, optimizer, and scheduler states to a checkpoint.
|
||||
|
||||
Args:
|
||||
local_path: Local filesystem path to save checkpoint.
|
||||
hdfs_path: Optional HDFS path to copy checkpoint.
|
||||
global_step: Integer training step number for naming.
|
||||
max_ckpt_to_keep: Maximum number of recent checkpoints to retain.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def load_checkpoint(self, local_path, hdfs_path=None, del_local_after_load=True):
|
||||
"""
|
||||
Load model, optimizer, and scheduler states from a checkpoint.
|
||||
|
||||
Args:
|
||||
local_path: Local filesystem path of the checkpoint.
|
||||
hdfs_path: Optional HDFS path where checkpoint is stored.
|
||||
del_local_after_load: Whether to delete local copy after loading.
|
||||
"""
|
||||
raise NotImplementedError
|
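The `@EngineRegistry.register("megatron")` decorator above and the `EngineRegistry.new(config.strategy, config)` call in `CriticWorker` below imply a simple name-to-class registry. The actual `EngineRegistry` lives in the engine base module and is not shown in this part of the diff; the following is only a minimal sketch of that pattern, inferred from how it is used here:

```python
from typing import Any, Callable


class EngineRegistry:
    """Maps a strategy name (e.g. "fsdp", "megatron") to a BaseEngine subclass."""

    _engines: dict[str, type] = {}

    @classmethod
    def register(cls, key: str) -> Callable[[type], type]:
        # Used as a class decorator: @EngineRegistry.register("megatron")
        def decorator(engine_cls: type) -> type:
            cls._engines[key] = engine_cls
            return engine_cls

        return decorator

    @classmethod
    def new(cls, key: str, config: Any):
        # Used by workers: EngineRegistry.new(config.strategy, config)
        return cls._engines[key](config)
```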
**`verl/workers/roles/__init__.py`** (new file, +17 lines)

```python
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .critic import CriticWorker

__all__ = ["CriticWorker"]
```
**`verl/workers/roles/actor.py`** (new file, +51 lines)

```python
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from verl import DataProto
from verl.single_controller.base import Worker
from verl.single_controller.base.decorator import Dispatch, register


class ActorWorker(Worker):
    """
    This worker can be instantiated as a standalone actor, a standalone rollout,
    a standalone reference policy, or a hybrid engine, depending on config.rollout.
    """

    def __init__(self, config):
        raise NotImplementedError

    @register(dispatch_mode=Dispatch.ONE_TO_ALL)
    def init_model(self):
        raise NotImplementedError

    @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
    def update_actor(self, data: DataProto):
        raise NotImplementedError

    @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
    def compute_log_prob(self, data: DataProto):
        raise NotImplementedError

    @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
    def compute_ref_log_prob(self, data: DataProto):
        raise NotImplementedError

    @register(dispatch_mode=Dispatch.ONE_TO_ALL)
    def save_checkpoint(self, local_path, hdfs_path=None, global_step=0, max_ckpt_to_keep=None):
        raise NotImplementedError

    @register(dispatch_mode=Dispatch.ONE_TO_ALL)
    def load_checkpoint(self, local_path, hdfs_path=None, del_local_after_load=False):
        raise NotImplementedError
```
**`verl/workers/roles/critic.py`** (new file, +183 lines)

```python
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
The critic worker for the PPO algorithm
"""

import logging
import os

import torch
from codetiming import Timer

from verl import DataProto
from verl.single_controller.base import Worker
from verl.single_controller.base.decorator import Dispatch, register
from verl.trainer.ppo import core_algos
from verl.utils.config import omega_conf_to_dataclass
from verl.utils.device import (
    get_device_id,
    get_nccl_backend,
)
from verl.utils.profiler import DistProfiler, DistProfilerExtension
from verl.utils.py_functional import append_to_dict
from verl.utils.torch_functional import masked_mean
from verl.workers.engine import EngineRegistry

logger = logging.getLogger(__file__)
logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))


class CriticWorker(Worker, DistProfilerExtension):
    def __init__(self, config):
        Worker.__init__(self)
        DistProfilerExtension.__init__(
            self, DistProfiler(rank=self.rank, config=omega_conf_to_dataclass(config.get("profiler")))
        )
        import torch.distributed

        if not torch.distributed.is_initialized():
            torch.distributed.init_process_group(backend=get_nccl_backend())
        self.config = config
        self.engine = EngineRegistry.new(self.config.strategy, self.config)

    @register(dispatch_mode=Dispatch.ONE_TO_ALL)
    def init_model(self):
        self.engine.init_model()

    def _post_fn_values(self, micro_batch, preds):
        response_length = micro_batch["responses"].size(-1)
        values = preds[:, -response_length - 1 : -1]

        use_remove_padding = self.config.model.get("use_remove_padding", False)
        if not use_remove_padding:
            values = values.squeeze(-1)

        return values, {"values": values.clone().detach()}

    @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
    @DistProfiler.annotate(color="cyan")
    def compute_values(self, data: DataProto):
        # Support all hardware
        data = data.to(get_device_id())
        micro_batch_size = self.config.forward_micro_batch_size_per_gpu
        data.meta_info["micro_batch_size"] = micro_batch_size
        data.meta_info["max_token_len"] = self.config.forward_max_token_len_per_gpu
        data.meta_info["use_dynamic_bsz"] = self.config.use_dynamic_bsz

        with self.engine.eval_mode():
            data = self.engine.shard_data(data=data)
            output = self.engine.infer_batch(data, post_fn=self._post_fn_values)
            response_mask = data.batch["response_mask"]
            values = output["values"] * response_mask  # Only action tokens have values
            output = DataProto.from_dict(tensors={"values": values})

            output = self.engine.unshard_data(data=output)
        output = output.to("cpu")
        return output

    def loss_fn(
        self, batch: DataProto, vpreds: dict[str, torch.Tensor]
    ) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:
        old_values = batch["values"]
        returns = batch["returns"]
        response_mask = batch["response_mask"]
        micro_batch_metrics = {}

        values, _ = self._post_fn_values(batch, vpreds)

        vf_loss, vf_clipfrac = core_algos.compute_value_loss(
            vpreds=values,
            values=old_values,
            returns=returns,
            response_mask=response_mask,
            cliprange_value=self.config.cliprange_value,
            loss_agg_mode=self.config.loss_agg_mode,
        )
        if self.config.use_dynamic_bsz:
            # relative to the dynamic bsz
            loss = vf_loss * (len(batch) / self.config.ppo_mini_batch_size)
        else:
            gradient_accumulation = self.config.ppo_mini_batch_size // self.config.ppo_micro_batch_size_per_gpu
            loss = vf_loss / gradient_accumulation

        micro_batch_metrics = {
            "critic/vf_loss": vf_loss.detach().item(),
            "critic/vf_clipfrac": vf_clipfrac.detach().item(),
            "critic/vpred_mean": masked_mean(values, response_mask).detach().item(),
        }

        return loss, micro_batch_metrics

    @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
    @DistProfiler.annotate(color="pink")
    def update_critic(self, data: DataProto):
        metrics = {}
        # Support all hardware
        data = data.to(get_device_id())
        # perform forward/backward passes and optimizer updates
        with self.engine.train_mode():
            data = self.engine.shard_data(data=data)

            with Timer(name="update_critic", logger=None) as timer:
                select_keys = [
                    "input_ids",
                    "responses",
                    "response_mask",
                    "attention_mask",
                    "position_ids",
                    "values",
                    "returns",
                ]
                batch = data.select(batch_keys=select_keys).batch
                has_multi_modal_inputs = "multi_modal_inputs" in data.non_tensor_batch.keys()

                # Split into mini-batches for the critic update
                # See PPO paper for details. https://arxiv.org/abs/1707.06347
                if has_multi_modal_inputs:
                    num_mini_batches = data.batch.batch_size[0] // self.config.ppo_mini_batch_size
                    non_tensor_select_keys = ["multi_modal_inputs"]
                    dataloader = data.select(select_keys, non_tensor_select_keys).chunk(num_mini_batches)
                else:
                    dataloader = batch.split(self.config.ppo_mini_batch_size)

                for epoch in range(self.config.ppo_epochs):
                    for batch_idx, mini_batch in enumerate(dataloader):
                        self.engine.optimizer_zero_grad()
                        mini_batch_metrics = self.engine.train_batch(mini_batch, self.loss_fn)
                        grad_norm = self.engine.optimizer_step()
                        mini_batch_metrics["critic/grad_norm"] = grad_norm.detach().item()
                        append_to_dict(metrics, mini_batch_metrics)
                self.engine.optimizer_zero_grad()
            delta_time = timer.last

            # TODO: should not access engine's flops_counter
            global_num_tokens = data.meta_info["global_token_num"]
            estimated_flops, promised_flops = self.engine.flops_counter.estimate_flops(global_num_tokens, delta_time)
            metrics["perf/mfu/critic"] = estimated_flops * self.config.ppo_epochs / promised_flops / self.world_size

            metrics["critic/lr"] = self.engine.lr_scheduler_step()[0]
            output = DataProto(batch=None, meta_info={"metrics": metrics})
            output = self.engine.unshard_data(data=output)

        output = output.to("cpu")
        return output

    @register(dispatch_mode=Dispatch.ONE_TO_ALL)
    def save_checkpoint(self, local_path, hdfs_path=None, global_step=0, max_ckpt_to_keep=None):
        self.engine.save_checkpoint(local_path, hdfs_path, global_step, max_ckpt_to_keep)

    @register(dispatch_mode=Dispatch.ONE_TO_ALL)
    def load_checkpoint(self, local_path, hdfs_path=None, del_local_after_load=True):
        self.engine.load_checkpoint(local_path, hdfs_path, del_local_after_load)
```
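A quick, self-contained illustration of the slice used in `_post_fn_values` above (toy numbers only, not tied to any real model): for a sequence laid out as `[prompt tokens | response tokens]`, the value prediction that scores response token *t* comes from the model output at the preceding position, which is exactly the `[-response_length - 1 : -1]` window.

```python
import torch

prompt_len, response_len = 3, 4
# One fake sequence of per-position value predictions: positions 0..6
preds = torch.arange(prompt_len + response_len, dtype=torch.float32).view(1, -1)

# Same slice as _post_fn_values: drop the last position, keep the
# response_len predictions that precede each response token.
values = preds[:, -response_len - 1 : -1]
print(values)  # tensor([[2., 3., 4., 5.]]) -> positions 2..5 score response tokens 0..3
```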