[trainer] refactor: Training Engine Interface and Development Plan (#1977)

# [Refactor] Training Engine Interface and Development Plan

## Motivation  
See the original RFC for background:
https://github.com/volcengine/verl/issues/1371

Modernizing our training loop requires that we:

- **Decouple** training-backend implementation from algorithm code so
each can evolve independently
- **Unify** on a single, well-defined `Engine` interface across
FSDP/Megatron/etc. backends
- **Enable** unit testing of each backend implementation in isolation
- **Guarantee** that algorithm “roles” (Critic, Actor, Rollout, Ref) remain
completely engine-agnostic

---

## Current Implementation  

This PR:
- Introduces an abstract `BaseEngine` class that defines a unified
training-engine interface.
- Implements `FSDPEngine`, a concrete `BaseEngine` using PyTorch
FullyShardedDataParallel.
- Provides a `CriticWorker` based on `FSDPEngine` that plugs seamlessly
into existing PPO training code without any changes.


### Classic Training Loop with the New Interface

```python
# 1. Build and initialize engine
engine = FSDPEngine(config)
engine.init_model()
engine.set_loss_fn(loss_fn)

# 2. Training loop
ctx = None  # running context threaded through the preprocess/postprocess hooks
for epoch in range(config.num_epochs):
    for batch in train_loader:
        # a) zero gradients
        engine.optimizer_zero_grad()

        # b) forward + backward
        with engine.train_mode():
            preds, loss, ctx = engine.forward_backward_step(
                batch,
                ctx,
                forward_only=False,
                preprocess_fn=preprocess_fn,
                postprocess_fn=postprocess_fn
            )

        # c) update and schedule
        grad_norm = engine.optimizer_step()
        current_lr = engine.lr_scheduler_step()

# 3. Evaluation
with engine.eval_mode():
    for micro_batch in val_loader:
        preds, ctx = engine.forward_backward_step(
            micro_batch,
            ctx,
            forward_only=True,
            preprocess_fn=preprocess_fn,
            postprocess_fn=postprocess_fn
        )
```
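
Checkpointing and device placement go through the same engine-level API. A
brief sketch (paths and step numbers are illustrative):

```python
# Persist model/optimizer/scheduler state (optionally mirrored to HDFS)
# and restore it later, without touching backend-specific code.
engine.save_checkpoint(local_path="ckpts/global_step_100", global_step=100, max_ckpt_to_keep=3)
engine.to("cpu")    # offload model and optimizer state between phases
engine.to("cuda")   # bring them back before the next training phase
engine.load_checkpoint(local_path="ckpts/global_step_100")
```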

### Detailed BaseEngine Interface
We now introduce an abstract base class, `BaseEngine`, which defines our
unified training-engine interface.

**Key enhancements over the original RFC:**
- **`train_mode()` / `eval_mode()`**  
Context managers to control parameter and activation load/offload at the
start and end of each loop.
- **`shard_data()` / `unshard_data()`**  
  APIs for partitioning and gathering data across devices or workers.  
- **`preprocess_fn` / `postprocess_fn` in `forward_backward_step()`**  
Hooks to apply custom transformations before and after each micro-batch
pass (see the sketch after this list).
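
As an illustration, here is a minimal sketch of what these callables might
look like. The signatures follow the docstrings below, while the tensor keys
(`input_ids`, `returns`, ...) are placeholder assumptions for a value-model
use case, not part of the interface:

```python
# Hypothetical hook and loss functions matching the BaseEngine contract below.
# The dictionary keys are illustrative; they are not fixed by the interface.

def preprocess_fn(batch, ctx):
    # (batch, ctx) -> (inputs, ctx): e.g. keep only the keys the model expects
    inputs = {k: v for k, v in batch.items()
              if k in ("input_ids", "attention_mask", "position_ids")}
    return inputs, ctx

def postprocess_fn(outputs, ctx):
    # (outputs, ctx) -> (predictions, ctx): e.g. strip a trailing value-head dim
    return outputs.squeeze(-1), ctx

def loss_fn(data, predictions, ctx):
    # (data, predictions, ctx) -> (loss_tensor, new_ctx), as used by set_loss_fn
    loss = ((predictions - data["returns"]) ** 2).mean()
    return loss, ctx
```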

Below are the detailed signatures for each core method.

```python

class BaseEngine(object):
    """
    Abstract base class defining the interface for model training engines.

    Engine implementations must subclass BaseEngine and provide concrete behavior for all methods.
    """
    def __init__(self, config):
        """
        Initialize the BaseEngine.

        Args:
            config: Configuration object containing parameters for engine setup.
        """
        raise NotImplementedError

    def init_model(self):
        """
        Instantiate or load the model, optimizer, and learning rate scheduler.

        Should prepare all components necessary for training or evaluation.
        """
        raise NotImplementedError

    def train_mode(self):
        """
        Context manager entry for switching the engine and model into training mode.

        Usage:
            with engine.train_mode():
                # runs in training mode
        """
        raise NotImplementedError

    def eval_mode(self):        
        """
        Context manager entry for switching the engine and model into evaluation mode.

        Usage:
            with engine.eval_mode():
                # runs in evaluation mode
        """
        raise NotImplementedError

    def forward_backward_step(self, 
                              batch, 
                              ctx=None, 
                              forward_only=False, 
                              preprocess_fn=None, 
                              postprocess_fn=None):
        """
        Execute a forward pass (and optional backward pass) over a batch of data.

        Args:
            batch: Raw batch data (e.g., tensors or mappings) to process.
            ctx: Optional context dict passed to preprocess/postprocess functions.
            forward_only: If True, skip gradient computation and backward pass.
            preprocess_fn: Function(batch, ctx) -> (inputs, ctx), applied before model call.
            postprocess_fn: Function(outputs, ctx) -> (predictions, ctx), applied after model call.

        Returns:
            If forward_only:
                (predictions, ctx)
            Else:
                (predictions, loss, ctx)
        """
        raise NotImplementedError

    def optimizer_zero_grad(self):
        """
        Zero out gradients of all parameters before starting a new backward pass.
        """
        raise NotImplementedError

    def optimizer_step(self):
        """
        Perform an optimization step to update model parameters based on accumulated gradients.

        Returns:
            grad_norm (float): The norm of the gradients before clipping or update.
        """
        raise NotImplementedError

    def lr_scheduler_step(self):
        """
        Advance the learning rate scheduler by one step.

        Returns:
            current_lr (float or list[float]): Updated learning rate(s).
        """
        raise NotImplementedError

    def shard_data(self, data):
        """
        Shard or partition data for distributed training or parallel execution.

        Args:
            data: Data structure to be sharded across devices/workers.

        Returns:
            Sharded data in the same format as input.
        """
        raise NotImplementedError

    def unshard_data(self, data):
        """
        Reconstruct or gather sharded data back to a unified format.

        Args:
            data: Sharded data structure to reconstruct.

        Returns:
            Unsharded, combined data.
        """
        raise NotImplementedError
        

    def set_loss_fn(self, loss_fn):
        """
        Set the loss function to be used during training.

        Args:
            loss_fn: Callable(data, predictions, ctx) -> (loss_tensor, new_ctx)
        """
        raise NotImplementedError

    def to(self, device: str, model: bool = True, optimizer: bool = True):
        """
        Move model parameters, optimizer states, or both to the specified device.

        Args:
            device: Target device identifier (e.g., "cuda" or "cpu").
            model: If True, move the model.
            optimizer: If True, move the optimizer states.
        """
        raise NotImplementedError


    def save_checkpoint(self, local_path, hdfs_path=None, global_step=0, max_ckpt_to_keep=None):
        """
        Save model, optimizer, and scheduler states to a checkpoint.

        Args:
            local_path: Local filesystem path to save checkpoint.
            hdfs_path: Optional HDFS path to copy checkpoint.
            global_step: Integer training step number for naming.
            max_ckpt_to_keep: Maximum number of recent checkpoints to retain.
        """
        raise NotImplementedError


    def load_checkpoint(self, local_path, hdfs_path=None, del_local_after_load=True):
        """
        Load model, optimizer, and scheduler states from a checkpoint.

        Args:
            local_path: Local filesystem path of the checkpoint.
            hdfs_path: Optional HDFS path where checkpoint is stored.
            del_local_after_load: Whether to delete local copy after loading.
        """
        raise NotImplementedError
```

### FSDPEngine Implementation

A concrete `FSDPEngine` implements all methods using PyTorch
FullyShardedDataParallel, supporting all the features that the existing
FSDP DP critic worker supports:

- Multi-GPU model sharding  
- Activation and optimizer offload  
- LoRA & sequence parallelism  
- Dynamic batch sizing and remove-padding
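
Engine construction goes through the `EngineRegistry` introduced in this
PR: `FSDPEngine` registers itself under the `"fsdp"` key, and a worker
looks the backend up by its configured strategy string. A minimal sketch,
assuming `config` is the worker's config object:

```python
from verl.workers.engine import BaseEngine, EngineRegistry

# EngineRegistry.new(...) resolves the key to a concrete BaseEngine subclass;
# "fsdp" maps to FSDPEngine because it is decorated with
# @EngineRegistry.register("fsdp").
engine: BaseEngine = EngineRegistry.new("fsdp", config)
engine.init_model()
```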

### CriticWorker Implementation based on the FSDPEngine
- Unchanged public API 
- Each role calls only BaseEngine methods (init_model,
train_mode/eval_mode, forward_backward_step, etc.)
- No modifications needed in existing algorithms (e.g., the PPO trainer)
- New roles can be plugged in exactly as the legacy workers are (a condensed
sketch of the critic update loop follows this list)
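
For reference, a condensed sketch of how the new `CriticWorker` drives the
engine during an update. It is simplified from the `update_critic` method
shipped in this PR; note the shipped worker uses the engine's mini-batch
helpers `train_batch`/`infer_batch` alongside the methods listed above:

```python
# Condensed from verl/workers/roles/critic.py in this PR. Key selection,
# dynamic batching, timing, MFU accounting, and DataProto packaging are omitted.
from verl.utils.py_functional import append_to_dict

def update_critic(self, data):  # method of the new CriticWorker
    metrics = {}
    with self.engine.train_mode():            # handles param/optimizer load + model.train()
        data = self.engine.shard_data(data=data)
        for mini_batch in data.batch.split(self.config.ppo_mini_batch_size):
            self.engine.optimizer_zero_grad()
            mini_batch_metrics = self.engine.train_batch(mini_batch, self.loss_fn)
            mini_batch_metrics["critic/grad_norm"] = self.engine.optimizer_step().detach().item()
            append_to_dict(metrics, mini_batch_metrics)
        metrics["critic/lr"] = self.engine.lr_scheduler_step()[0]
    return metrics
```

Every call above is part of the engine interface, so swapping `FSDPEngine`
for a future `MegatronEngine` requires no change to the role.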

## Development Plan
We’ll roll this out in three gated phases, controlled by a feature flag
(`use_legacy_worker_impl`).
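
The flag lives at `trainer.use_legacy_worker_impl` and takes the string
values `auto`, `enable`, or `disable`. Worker selection in the PPO
`TaskRunner` then follows this pattern (simplified from the change in this
PR):

```python
# Simplified from the TaskRunner change in this PR: the flag decides whether
# the legacy FSDP CriticWorker or the new engine-based CriticWorker is used.
use_legacy_worker_impl = config.trainer.get("use_legacy_worker_impl", "auto")
if use_legacy_worker_impl in ["auto", "enable"]:
    from verl.workers.fsdp_workers import CriticWorker   # legacy path (current default)
elif use_legacy_worker_impl == "disable":
    from verl.workers.roles import CriticWorker          # new engine-based path
else:
    raise ValueError(f"Invalid use_legacy_worker_impl: {use_legacy_worker_impl}")
```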

### Phase 1: Engine Development
> Flag: `use_legacy_worker_impl` keeps the legacy path as the default (`auto` / `enable`)
> New interface under active development

- Refactor Critic, Actor, Rollout, Ref to use only BaseEngine APIs
- Design a hierarchical, immutable config system for engine/backends
- Ensure PPO training curves and final accuracy match legacy
implementation

### Phase 2: Migration
> Flag: `use_legacy_worker_impl = disable` becomes the default; the legacy path logs a
deprecation warning
> All new code targets the new interface; 2–3 months of
integration/stress testing

- Enforce new interface for all feature work
- Gather benchmarks, bug reports, and performance data

### Phase 3: Cleanup
> After Phase 2 validation:
- Remove legacy worker code and flags
- Finalize documentation, update changelogs, close deprecation notices

Please review this refactor and share any feedback or concerns!
Contributions are welcome.
---

Commit 9d7cba4e12 (parent 223caf7022), authored by Ziheng Jiang on 2025-07-17 22:05:21 -07:00, committed via GitHub.
17 changed files with 1538 additions and 2 deletions.

---

@@ -101,10 +101,15 @@ jobs:
ray stop --force
python3 examples/data_preprocess/gsm8k.py
# HF sanity
- name: Running GSM8K E2E training tests on 1 L20 GPU with hf for santiy
- name: Running GSM8K E2E training tests on 1 L20 GPU with hf for sanity
run: |
ray stop --force
bash tests/special_e2e/ppo_trainer/run_single_gpu.sh
# HF sanity
- name: Running GSM8K E2E training tests on 1 L20 GPU with engine interface for sanity.
run: |
ray stop --force
bash tests/special_e2e/ppo_trainer/run_single_gpu_with_engine.sh
# Function RM
- name: Running GSM8K E2E training tests on 8 L20 GPUs with rmpad using function rm with validation and saving (FSDP_SIZE=8)
run: |

---

@@ -38,4 +38,5 @@ python3 -m verl.trainer.main_ppo \
trainer.nnodes=1 \
trainer.save_freq=20 \
trainer.test_freq=1 \
trainer.use_legacy_worker_impl=auto \
trainer.total_epochs=15 $@

---

@@ -0,0 +1,25 @@
PYTHONUNBUFFERED=1 python3 -m verl.trainer.main_ppo \
data.train_files=$HOME/data/gsm8k/train.parquet \
data.val_files=$HOME/data/gsm8k/test.parquet \
data.train_batch_size=256 \
data.max_prompt_length=512 \
data.max_response_length=256 \
actor_rollout_ref.model.path=Qwen/Qwen2.5-0.5B-Instruct \
actor_rollout_ref.actor.optim.lr=1e-6 \
actor_rollout_ref.actor.ppo_mini_batch_size=64 \
actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=8 \
actor_rollout_ref.rollout.tensor_model_parallel_size=1 \
actor_rollout_ref.rollout.gpu_memory_utilization=0.4 \
actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4 \
critic.optim.lr=1e-5 \
critic.model.path=Qwen/Qwen2.5-0.5B-Instruct \
critic.ppo_micro_batch_size_per_gpu=4 \
algorithm.kl_ctrl.kl_coef=0.001 \
trainer.logger=['console'] \
trainer.val_before_train=False \
trainer.n_gpus_per_node=1 \
trainer.nnodes=1 \
actor_rollout_ref.rollout.name=hf \
trainer.use_legacy_worker_impl=disable \
trainer.total_training_steps=2

---

@@ -36,6 +36,7 @@ CUDA_KEYWORD_CHECK_WHITELIST = [
"verl/trainer/ppo/ray_trainer.py", # appear in default device_name
"verl/utils/reward_score/sandbox_fusion/utils.py", # appear in sandbox language type
"verl/workers/reward_model/megatron/reward_model.py", # appear in default device_name
"verl/workers/engine/fsdp/engine_impl.py",
]
# directory or file path must contain keyword "nccl"

---

@@ -214,6 +214,7 @@ trainer:
max_critic_ckpt_to_keep: null
ray_wait_register_center_timeout: 300
device: cuda
use_legacy_worker_impl: auto
data:
tokenizer: null
use_shm: false

---

@@ -322,6 +322,10 @@ trainer:
# Device to run training on (e.g., "cuda", "cpu")
device: cuda
# whether to use legacy worker implementation
# mode: "auto", "enable", or "disable"
use_legacy_worker_impl: auto
# configs related to ray initialization
ray_init:

---

@@ -127,7 +127,20 @@ class TaskRunner:
if config.actor_rollout_ref.actor.strategy in {"fsdp", "fsdp2"}:
assert config.critic.strategy in {"fsdp", "fsdp2"}
from verl.single_controller.ray import RayWorkerGroup
from verl.workers.fsdp_workers import ActorRolloutRefWorker, AsyncActorRolloutRefWorker, CriticWorker
from verl.workers.fsdp_workers import ActorRolloutRefWorker, AsyncActorRolloutRefWorker
use_legacy_worker_impl = config.trainer.get("use_legacy_worker_impl", "auto")
if use_legacy_worker_impl in ["auto", "enable"]:
# import warnings
# warnings.warn(f"Legacy worker impl is going to be deprecated, will be removed in the future. \
# Please set trainer.use_legacy_worker_impl = false to switch to the new worker implementation.")
from verl.workers.fsdp_workers import CriticWorker
elif use_legacy_worker_impl == "disable":
from verl.workers.roles import CriticWorker
print("Using new worker implementation")
else:
raise ValueError(f"Invalid use_legacy_worker_impl: {use_legacy_worker_impl}")
actor_rollout_cls = (
AsyncActorRolloutRefWorker

---

@@ -0,0 +1,17 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .base import BaseEngine, EngineRegistry
from .fsdp import FSDPEngine
__all__ = ["BaseEngine", "EngineRegistry", "FSDPEngine"]

---

verl/workers/engine/base.py (new file, 235 lines)
@@ -0,0 +1,235 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
The abstract base class defining the interface for model training engines.
"""
from typing import Callable
import torch
from verl import DataProto
class BaseEngine:
"""
Abstract base class defining the interface for model training engines.
Engine implementations must subclass BaseEngine and provide concrete behavior for all methods.
"""
def __init__(self, config):
"""
Initialize the BaseEngine.
Args:
config: Configuration object containing parameters for engine setup.
"""
raise NotImplementedError
def init_model(self):
"""
Instantiate or load the model, optimizer, and learning rate scheduler.
Should prepare all components necessary for training or evaluation.
"""
raise NotImplementedError
def train_mode(self):
"""
Context manager entry for switching the engine and model into training mode.
Usage:
with engine.train_mode():
# runs in training mode
"""
raise NotImplementedError
def eval_mode(self):
"""
Context manager entry for switching the engine and model into evaluation mode.
Usage:
with engine.eval_mode():
# runs in evaluation mode
"""
raise NotImplementedError
def infer_batch(
self,
data: DataProto,
post_fn: Callable[[DataProto, torch.Tensor], tuple[torch.Tensor, dict[str, torch.Tensor]]],
) -> dict[str, torch.Tensor]:
"""
Perform inference on a mini batch of data.
Args:
data: The input data for inference, typically containing tensors and metadata.
post_fn: A post-processing function that takes a micro-batch and predictions as input,
and returns a tuple containing processed predictions and a dictionary of outputs.
Returns:
dict[str, torch.Tensor]: A dictionary containing the predictions for the entire batch.
"""
raise NotImplementedError
def train_batch(
self,
data: DataProto,
loss_fn: Callable[[DataProto, torch.Tensor], tuple[torch.Tensor, dict[str, torch.Tensor]]],
) -> dict[str, torch.Tensor]:
"""
Perform a training step on a mini-batch of data.
Args:
data (DataProto): The input data for training, typically containing tensors and metadata.
loss_fn (Callable): A function that computes the loss and metrics given a micro-batch and predictions.
Returns:
dict[str, torch.Tensor]: A dictionary containing the aggregated training metrics for the mini-batch.
"""
raise NotImplementedError
def optimizer_zero_grad(self):
"""
Zero out gradients of all parameters before starting a new backward pass.
"""
raise NotImplementedError
def optimizer_step(self):
"""
Perform an optimization step to update model parameters based on accumulated gradients.
Returns:
grad_norm (float): The norm of the gradients before clipping or update.
"""
raise NotImplementedError
def lr_scheduler_step(self):
"""
Advance the learning rate scheduler by one step.
Returns:
current_lr (float or list[float]): Updated learning rate(s).
"""
raise NotImplementedError
def shard_data(self, data):
"""
Shard or partition data for distributed training or parallel execution.
Args:
data: Data structure to be sharded across devices/workers.
Returns:
Sharded data in the same format as input.
"""
raise NotImplementedError
def unshard_data(self, data):
"""
Reconstruct or gather sharded data back to a unified format.
Args:
data: Sharded data structure to reconstruct.
Returns:
Unsharded, combined data.
"""
raise NotImplementedError
def to(self, device: str, model: bool = True, optimizer: bool = True):
"""
Move model parameters, optimizer states, or both to the specified device.
Args:
device: Target device identifier.
model: If True, move the model.
optimizer: If True, move the optimizer states.
"""
raise NotImplementedError
def save_checkpoint(self, local_path, hdfs_path=None, global_step=0, max_ckpt_to_keep=None):
"""
Save model, optimizer, and scheduler states to a checkpoint.
Args:
local_path: Local filesystem path to save checkpoint.
hdfs_path: Optional HDFS path to copy checkpoint.
global_step: Integer training step number for naming.
max_ckpt_to_keep: Maximum number of recent checkpoints to retain.
"""
raise NotImplementedError
def load_checkpoint(self, local_path, hdfs_path=None, del_local_after_load=True):
"""
Load model, optimizer, and scheduler states from a checkpoint.
Args:
local_path: Local filesystem path of the checkpoint.
hdfs_path: Optional HDFS path where checkpoint is stored.
del_local_after_load: Whether to delete local copy after loading.
"""
raise NotImplementedError
class EngineRegistry:
"""
A registry for managing and instantiating different types of training engines.
This class uses a dictionary to store engine classes, mapping a string key to each class.
It provides a decorator `register` to add new engines to the registry and a `new` method
to create an instance of a registered engine.
"""
_engines = {}
@classmethod
def register(cls, key):
"""
A class method decorator that registers an engine class with a given key.
This allows for dynamic instantiation of engine classes by their registered key.
Args:
key (str): The identifier to associate with the engine class.
Returns:
A decorator function that takes an engine class and registers it.
"""
def decorator(engine_class):
assert issubclass(engine_class, BaseEngine)
cls._engines[key] = engine_class
return engine_class
return decorator
@classmethod
def new(cls, key, *args, **kwargs):
"""
Function to create a new training engine instance based on the provided config.
Args:
key: A configuration object containing the engine key and other settings.
*args: Variable length argument list.
**kwargs: Arbitrary keyword arguments.
Returns:
engine: An instance of the training engine corresponding to the config.
Raises:
NotImplementedError: If the engine key in the config does not match any known engines.
"""
if key in cls._engines:
return cls._engines[key](*args, **kwargs)
else:
raise NotImplementedError(f"Unknown engine: {key}")

---

@@ -0,0 +1,16 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .engine_impl import FSDPEngine
__all__ = ["FSDPEngine"]

---

@@ -0,0 +1,727 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
The concrete Engine implementation using PyTorch FullyShardedDataParallel (FSDP)
"""
import gc
import itertools
import logging
import os
import warnings
from typing import Callable
import torch
import torch.distributed
from omegaconf import OmegaConf
from peft import LoraConfig, TaskType, get_peft_model
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from verl import DataProto
from verl.models.transformers.monkey_patch import apply_monkey_patch
from verl.utils import hf_processor, hf_tokenizer
from verl.utils.activation_offload import enable_activation_offloading
from verl.utils.checkpoint.fsdp_checkpoint_manager import FSDPCheckpointManager
from verl.utils.debug import log_gpu_memory_usage
from verl.utils.device import (
get_device_id,
get_device_name,
get_torch_device,
is_cuda_available,
is_npu_available,
)
from verl.utils.flops_counter import FlopsCounter
from verl.utils.fs import copy_to_local
from verl.utils.fsdp_utils import (
CPUOffloadPolicy,
FSDPModule,
MixedPrecisionPolicy,
apply_fsdp2,
fsdp2_clip_grad_norm_,
fsdp2_load_full_state_dict,
get_fsdp_wrap_policy,
get_init_weight_context_manager,
init_fn,
load_fsdp_model_to_gpu,
load_fsdp_optimizer,
offload_fsdp_model_to_cpu,
offload_fsdp_optimizer,
)
from verl.utils.import_utils import import_external_libs
from verl.utils.py_functional import append_to_dict, convert_to_regular_types
from verl.utils.seqlen_balancing import get_reverse_idx, rearrange_micro_batches
from verl.utils.ulysses import gather_outpus_and_unpad, ulysses_pad_and_slice_inputs
from verl.workers.sharding_manager.fsdp_ulysses import FSDPUlyssesShardingManager
if is_cuda_available:
from flash_attn.bert_padding import index_first_axis, pad_input, rearrange, unpad_input
elif is_npu_available:
from transformers.integrations.npu_flash_attention import index_first_axis, pad_input, rearrange, unpad_input
from ..base import BaseEngine, EngineRegistry
from .utils import create_device_mesh, get_sharding_strategy
logger = logging.getLogger(__file__)
logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))
device_name = get_device_name()
@EngineRegistry.register("fsdp")
class FSDPEngine(BaseEngine):
"""
Concrete Engine implementation using PyTorch FullyShardedDataParallel (FSDP).
Supports model sharding, activation/optimizer offloading, LoRA, and sequence parallelism.
"""
def __init__(self, config):
"""
Initialize the FSDPEngine.
Sets up distributed device meshes, LoRA, and offload policies based on config.
Args:
config: Configuration object with FSDP and model settings.
"""
self.config = config
self.rank = torch.distributed.get_rank()
# build device mesh for Ulysses Sequence Parallel
world_size = torch.distributed.get_world_size()
from torch.distributed.device_mesh import init_device_mesh
fsdp_size = self.config.model.fsdp_config.fsdp_size
self.device_mesh = create_device_mesh(world_size=world_size, fsdp_size=fsdp_size)
self.use_remove_padding = config.model.get("use_remove_padding", False)
self.ulysses_device_mesh = None
self.ulysses_sequence_parallel_size = self.config.get("ulysses_sequence_parallel_size", 1)
dp = world_size // self.ulysses_sequence_parallel_size
if self.ulysses_sequence_parallel_size > 1:
self.ulysses_device_mesh = init_device_mesh(
device_name, mesh_shape=(dp, self.ulysses_sequence_parallel_size), mesh_dim_names=["dp", "sp"]
)
self.ulysses_sharding_manager = FSDPUlyssesShardingManager(self.ulysses_device_mesh)
# set FSDP offload params
self._is_offload_param = self.config.model.fsdp_config.param_offload
self._is_offload_optimizer = self.config.model.fsdp_config.optimizer_offload
# normalize config
self.config.ppo_mini_batch_size *= self.config.rollout_n
self.config.ppo_mini_batch_size //= torch.distributed.get_world_size() // self.ulysses_sequence_parallel_size
if self.config.ppo_micro_batch_size is not None:
self.config.ppo_micro_batch_size //= (
torch.distributed.get_world_size() // self.ulysses_sequence_parallel_size
)
self.config.forward_micro_batch_size //= (
torch.distributed.get_world_size() // self.ulysses_sequence_parallel_size
)
self.config.ppo_micro_batch_size_per_gpu = self.config.ppo_micro_batch_size
self.config.forward_micro_batch_size_per_gpu = self.config.forward_micro_batch_size
if self.config.ppo_micro_batch_size_per_gpu is not None:
assert self.config.ppo_mini_batch_size % self.config.ppo_micro_batch_size_per_gpu == 0, (
f"normalized ppo_mini_batch_size {self.config.ppo_mini_batch_size} should be divisible by "
f"ppo_micro_batch_size_per_gpu {self.config.ppo_micro_batch_size_per_gpu}"
)
assert self.config.ppo_mini_batch_size // self.config.ppo_micro_batch_size_per_gpu > 0, (
f"normalized ppo_mini_batch_size {self.config.ppo_mini_batch_size} should be larger than "
f"ppo_micro_batch_size_per_gpu {self.config.ppo_micro_batch_size_per_gpu}"
)
self._is_lora = self.config.model.get("lora_rank", 0) > 0
def init_model(self):
"""
Build the model, optimizer, and learning rate scheduler under FSDP.
Applies device, dtype, and precision configurations, including mixed precision.
Sets up checkpoint manager and FLOPs counter.
"""
# This is used to import external_lib into the huggingface systems
import_external_libs(self.config.model.get("external_lib", None))
self.module, self.optimizer, self.lr_scheduler = self._build_model_optimizer(self.config)
if self._is_offload_param:
offload_fsdp_model_to_cpu(self.module)
log_gpu_memory_usage("After offload model during init", logger=logger)
if self._is_offload_optimizer:
offload_fsdp_optimizer(optimizer=self.optimizer)
log_gpu_memory_usage("After offload optimizer during init", logger=logger)
self.flops_counter = FlopsCounter(self.model_config)
self.checkpoint_manager = FSDPCheckpointManager(
model=self.module,
optimizer=self.optimizer,
lr_scheduler=self.lr_scheduler,
processing_class=self.processor if self.processor is not None else self.tokenizer,
checkpoint_contents=self.config.checkpoint,
)
def _build_model_optimizer(self, config):
# the following line is necessary
from torch import optim
from torch.distributed.fsdp import MixedPrecision
from verl.utils.model import load_valuehead_model, print_model_size
from verl.utils.torch_dtypes import PrecisionType
use_shm = config.model.get("use_shm", False)
local_path = copy_to_local(config.model.path, use_shm=use_shm)
# note that the tokenizer between actor and critic may be different. So override tokenizer info with actor info
# using random initialized model from any architecture. May not be the same as Actor.
tokenizer_path = copy_to_local(config.model.tokenizer_path, use_shm=use_shm)
self.tokenizer = hf_tokenizer(tokenizer_path, trust_remote_code=config.model.get("trust_remote_code", False))
self.processor = hf_processor(tokenizer_path, trust_remote_code=config.model.get("trust_remote_code", False))
if self.config.model.get("custom_chat_template", None) is not None:
if self.processor is not None:
self.processor.chat_template = self.config.model.custom_chat_template
else:
self.tokenizer.chat_template = self.config.model.custom_chat_template
override_config = OmegaConf.to_container(self.config.model.get("override_config", OmegaConf.create()))
override_config_kwargs = {
"bos_token_id": self.tokenizer.bos_token_id,
"eos_token_id": self.tokenizer.eos_token_id,
"pad_token_id": self.tokenizer.pad_token_id,
}
override_config_kwargs.update(override_config)
if self.rank == 0:
print(f"Engine overriding config {override_config_kwargs}")
torch_dtype = self.config.model.fsdp_config.get("model_dtype", "fp32")
torch_dtype = PrecisionType.to_dtype(torch_dtype)
from transformers import AutoConfig
model_config = AutoConfig.from_pretrained(
local_path,
attn_implementation="flash_attention_2",
trust_remote_code=config.model.get("trust_remote_code", False),
)
model_config.num_labels = 1
# patch for kimi-vl
if getattr(model_config, "model_type", None) == "kimi_vl":
model_config.text_config.topk_method = "greedy"
init_context = get_init_weight_context_manager(
use_meta_tensor=not model_config.tie_word_embeddings, mesh=self.device_mesh
)
with init_context(), warnings.catch_warnings():
warnings.simplefilter("ignore")
model_config.classifier_dropout = 0.0
model_config.hidden_dropout = "0"
model_config.summary_dropout_prob = 0.0
module = load_valuehead_model(
local_path,
torch_dtype,
model_config,
config.model.get("trust_remote_code", False),
)
apply_monkey_patch(
model=module,
use_remove_padding=self.use_remove_padding,
ulysses_sp_size=self.ulysses_sequence_parallel_size,
)
# some parameters may not in torch_dtype
module.to(torch_dtype)
if config.model.get("enable_gradient_checkpointing", False):
module.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
if self._is_lora:
print("Applying LoRA to the module")
module.enable_input_require_grads()
# Convert config to regular Python types before creating PEFT model
lora_config = {
"task_type": TaskType.CAUSAL_LM,
"r": self.config.model.lora_rank,
"lora_alpha": self.config.model.lora_alpha,
"target_modules": convert_to_regular_types(self.config.model.target_modules),
"bias": "none",
}
module = get_peft_model(module, LoraConfig(**lora_config))
if self.rank == 0:
print_model_size(module)
self.model_config = model_config
fsdp_config = self.config.model.fsdp_config
mixed_precision_config = fsdp_config.get("mixed_precision", None)
if mixed_precision_config is not None:
param_dtype = PrecisionType.to_dtype(mixed_precision_config.get("param_dtype", "bf16"))
reduce_dtype = PrecisionType.to_dtype(mixed_precision_config.get("reduce_dtype", "fp32"))
buffer_dtype = PrecisionType.to_dtype(mixed_precision_config.get("buffer_dtype", "fp32"))
else:
param_dtype = torch.bfloat16
reduce_dtype = torch.float32
buffer_dtype = torch.float32
mixed_precision = MixedPrecision(param_dtype=param_dtype, reduce_dtype=reduce_dtype, buffer_dtype=buffer_dtype)
auto_wrap_policy = get_fsdp_wrap_policy(
module=module,
config=self.config.model.fsdp_config.wrap_policy,
is_lora=self.config.model.get("lora_rank", 0) > 0,
)
log_gpu_memory_usage("Before FSDP", logger=None)
fsdp_mesh = self.device_mesh
sharding_strategy = get_sharding_strategy(fsdp_mesh)
# Note: We force turn off CPUOffload because it causes incorrect results when using grad accumulation
if config.strategy == "fsdp":
module = FSDP(
module,
param_init_fn=init_fn,
use_orig_params=False,
auto_wrap_policy=auto_wrap_policy,
device_id=get_device_id(),
sharding_strategy=sharding_strategy,
mixed_precision=mixed_precision,
sync_module_states=True,
forward_prefetch=self.config.model.fsdp_config.forward_prefetch,
device_mesh=self.device_mesh,
cpu_offload=None,
)
elif config.strategy == "fsdp2":
assert CPUOffloadPolicy is not None, "PyTorch version >= 2.4 is required for using fully_shard API (FSDP2)"
mp_policy = MixedPrecisionPolicy(
param_dtype=param_dtype, reduce_dtype=reduce_dtype, cast_forward_inputs=True
)
offload_policy = None
if fsdp_config.offload_policy:
self._is_offload_param = False
self._is_offload_optimizer = False
offload_policy = CPUOffloadPolicy(pin_memory=True)
fsdp_kwargs = {
"mesh": fsdp_mesh,
"mp_policy": mp_policy,
"offload_policy": offload_policy,
"reshard_after_forward": fsdp_config.reshard_after_forward,
}
full_state = module.state_dict()
apply_fsdp2(module, fsdp_kwargs, fsdp_config)
fsdp2_load_full_state_dict(module, full_state, fsdp_mesh, offload_policy)
else:
raise NotImplementedError(f"Unknown strategy {config.strategy}")
if config.model.get("enable_activation_offload", False):
enable_gradient_checkpointing = config.model.get("enable_gradient_checkpointing", False)
enable_activation_offloading(module, config.strategy, enable_gradient_checkpointing)
log_gpu_memory_usage("After FSDP", logger=None)
optimizer = optim.AdamW(
module.parameters(),
lr=config.optim.lr,
betas=config.optim.get("betas", (0.9, 0.999)),
weight_decay=config.optim.get("weight_decay", 1e-2),
)
total_steps = config.optim.get("total_training_steps", 0)
num_warmup_steps = int(config.optim.get("lr_warmup_steps", -1))
warmup_style = config.optim.get("warmup_style", "constant")
if num_warmup_steps < 0:
num_warmup_steps_ratio = config.optim.get("lr_warmup_steps_ratio", 0.0)
num_warmup_steps = int(num_warmup_steps_ratio * total_steps)
if self.rank == 0:
print(f"Total steps: {total_steps}, num_warmup_steps: {num_warmup_steps}")
from verl.utils.torch_functional import get_constant_schedule_with_warmup, get_cosine_schedule_with_warmup
if warmup_style == "constant":
lr_scheduler = get_constant_schedule_with_warmup(optimizer=optimizer, num_warmup_steps=num_warmup_steps)
elif warmup_style == "cosine":
lr_scheduler = get_cosine_schedule_with_warmup(
optimizer=optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=total_steps
)
else:
raise NotImplementedError(f"Warmup style {warmup_style} is not supported")
return module, optimizer, lr_scheduler
def train_mode(self):
"""
Return a context manager that switches to training mode with FSDP-specific handling.
Includes parameter and optimizer offload entry/exit.
"""
return EngineTrainModeCtx(self)
def eval_mode(self):
"""
Return a context manager that switches to evaluation mode with FSDP-specific handling.
Includes activation offload entry/exit.
"""
return EngineEvalModeCtx(self)
def shard_data(self, data):
"""
Preprocess data into sharded format via UlyssesShardingManager.
"""
return self.ulysses_sharding_manager.preprocess_data(data)
def unshard_data(self, data):
"""
Postprocess data from sharded format back to full format.
"""
return self.ulysses_sharding_manager.postprocess_data(data)
def get_default_ctx(self):
use_value_head_model = hasattr(self.module, "v_head")
ctx = {
"use_value_head_model": use_value_head_model,
"ulysses_sequence_parallel_size": self.ulysses_sequence_parallel_size,
}
return ctx
def _forward_micro_batch(self, micro_batch):
multi_modal_inputs = {}
if "multi_modal_inputs" in micro_batch.keys():
for key in micro_batch["multi_modal_inputs"][0].keys():
multi_modal_inputs[key] = torch.cat(
[inputs[key] for inputs in micro_batch["multi_modal_inputs"]], dim=0
)
with torch.autocast(device_type=device_name, dtype=torch.bfloat16):
input_ids = micro_batch["input_ids"]
batch, seqlen = input_ids.shape
attention_mask = micro_batch["attention_mask"]
position_ids = micro_batch["position_ids"]
if position_ids.dim() == 3: # qwen2vl mrope
position_ids = position_ids.transpose(0, 1)
if self.use_remove_padding:
input_ids_rmpad, indices, *_ = unpad_input(
input_ids.unsqueeze(-1), attention_mask
) # input_ids_rmpad (total_nnz, ...)
input_ids_rmpad = input_ids_rmpad.transpose(0, 1) # (1, total_nnz)
# unpad the position_ids to align the rotary
if position_ids.dim() == 3:
position_ids_rmpad = (
index_first_axis(rearrange(position_ids, "c b s ... -> (b s) c ..."), indices)
.transpose(0, 1)
.unsqueeze(1)
) # (3, bsz, seqlen) -> (3, 1, bsz * seqlen)
else:
position_ids_rmpad = index_first_axis(
rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."), indices
).transpose(0, 1)
# pad and slice the inputs if sp > 1
if self.ulysses_sequence_parallel_size > 1:
input_ids_rmpad, position_ids_rmpad, pad_size = ulysses_pad_and_slice_inputs(
input_ids_rmpad, position_ids_rmpad, sp_size=self.ulysses_sequence_parallel_size
)
# only pass input_ids and position_ids to enable flash_attn_varlen
preds = self.module(
input_ids=input_ids_rmpad,
attention_mask=None,
position_ids=position_ids_rmpad,
**multi_modal_inputs,
use_cache=False,
) # prevent model thinks we are generating
if hasattr(self.module, "v_head"):
# For trl.AutoModelForCausalLMWithValueHead
preds_rmpad = preds[2].squeeze(0).unsqueeze(-1)
else:
preds_rmpad = preds.logits
preds_rmpad = preds_rmpad.squeeze(0) # (total_nnz)
# gather output if sp > 1
if self.ulysses_sequence_parallel_size > 1:
preds_rmpad = gather_outpus_and_unpad(preds_rmpad, gather_dim=0, unpad_dim=0, padding_size=pad_size)
# pad it back
preds = pad_input(preds_rmpad, indices=indices, batch=batch, seqlen=seqlen).squeeze(-1)
else:
preds = self.module(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
**multi_modal_inputs,
use_cache=False,
) # prevent model thinks we are generating
if hasattr(self.module, "v_head"):
# For trl.AutoModelForCausalLMWithValueHead
preds = preds[2]
else:
preds = preds.logits
return preds
def infer_batch(
self,
data: DataProto,
post_fn: Callable[[DataProto, torch.Tensor], tuple[torch.Tensor, dict[str, torch.Tensor]]],
) -> dict[str, torch.Tensor]:
"""
Perform inference on a mini batch of data.
Args:
data: The input data for inference, typically containing tensors and metadata.
post_fn: A post-processing function that takes a micro-batch and predictions as input,
and returns a tuple containing processed predictions and a dictionary of outputs.
Returns:
dict[str, torch.Tensor]: A dictionary containing the predictions for the entire batch.
"""
assert self.mode == "eval"
micro_batch_size = data.meta_info["micro_batch_size"]
select_keys = ["responses", "input_ids", "attention_mask", "position_ids"]
batch = data.select(batch_keys=select_keys).batch
use_dynamic_bsz = data.meta_info["use_dynamic_bsz"]
has_multi_modal_inputs = "multi_modal_inputs" in data.non_tensor_batch.keys()
if has_multi_modal_inputs:
num_micro_batches = data.batch.batch_size[0] // micro_batch_size
non_tensor_select_keys = ["multi_modal_inputs"]
micro_batches = data.select(select_keys, non_tensor_select_keys).chunk(num_micro_batches)
elif use_dynamic_bsz:
# split using dynamic bsz
max_token_len = data.meta_info["max_token_len"] * self.ulysses_sequence_parallel_size
micro_batches, indices = rearrange_micro_batches(batch=batch, max_token_len=max_token_len)
else:
micro_batches = batch.split(micro_batch_size)
preds_list = {}
for micro_batch in micro_batches:
if isinstance(micro_batch, DataProto):
micro_batch = {**micro_batch.batch, **micro_batch.non_tensor_batch}
with torch.no_grad():
# micro_batch_preds would be a dict[str, torch.Tensor]
preds = self._forward_micro_batch(micro_batch)
_, outputs = post_fn(micro_batch, preds)
assert isinstance(outputs, dict)
# append micro batch preds to dict[str, List[torch.Tensor]]
append_to_dict(preds_list, outputs)
# reorganize mini batch preds from
# dict[str, List[torch.Tensor]] to dict[str, torch.Tensor]
mini_batch_preds = {}
for key, t_list in preds_list.items():
t_concat = torch.concat(t_list, dim=0)
if use_dynamic_bsz:
indices = list(itertools.chain.from_iterable(indices))
assert len(indices) == t_concat.size(0), f"{len(indices)} vs. {t_concat.size()}"
revert_indices = torch.tensor(get_reverse_idx(indices), dtype=torch.long)
t_concat = t_concat[revert_indices]
mini_batch_preds[key] = t_concat
return mini_batch_preds
def train_batch(
self,
data: DataProto,
loss_fn: Callable[[DataProto, torch.Tensor], tuple[torch.Tensor, dict[str, torch.Tensor]]],
) -> dict[str, torch.Tensor]:
"""
Perform a training step on a mini-batch of data.
Args:
data (DataProto): The input data for training, typically containing tensors and metadata.
loss_fn (Callable): A function that computes the loss and metrics given a micro-batch and predictions.
Returns:
dict[str, torch.Tensor]: A dictionary containing the aggregated training metrics for the mini-batch.
"""
assert self.mode == "train"
# split batch into micro_batches
mini_batch = data
select_keys = ["input_ids", "responses", "response_mask", "attention_mask", "position_ids"]
if "multi_modal_inputs" in mini_batch:
non_tensor_select_keys = ["multi_modal_inputs"]
num_micro_batches = mini_batch.batch.batch_size[0] // self.config.ppo_micro_batch_size_per_gpu
micro_batches = mini_batch.select(select_keys, non_tensor_select_keys).chunk(num_micro_batches)
self.config.ppo_mini_batch_size // self.config.ppo_micro_batch_size_per_gpu
elif self.config.use_dynamic_bsz:
max_token_len = self.config.ppo_max_token_len_per_gpu * self.ulysses_sequence_parallel_size
micro_batches, _ = rearrange_micro_batches(batch=mini_batch, max_token_len=max_token_len)
else:
micro_batches = mini_batch.split(self.config.ppo_micro_batch_size_per_gpu)
self.config.ppo_mini_batch_size // self.config.ppo_micro_batch_size_per_gpu
mini_batch_metrics = {}
for micro_batch in micro_batches:
# Support all devices
if isinstance(micro_batch, DataProto):
micro_batch = {**micro_batch.batch.to(get_device_id()), **micro_batch.non_tensor_batch}
else:
micro_batch = micro_batch.to(get_device_id()) # critic device is cpu when using offload
preds = self._forward_micro_batch(micro_batch)
loss, micro_batch_metrics = loss_fn(micro_batch, preds)
append_to_dict(mini_batch_metrics, micro_batch_metrics)
loss.backward()
return mini_batch_metrics
def optimizer_zero_grad(self):
"""
Zero gradients and enforce FSDP grad-clipping logic.
"""
self.optimizer.zero_grad()
def optimizer_step(self):
"""
Clip gradients, skip update if non-finite, and step optimizer.
Returns:
grad_norm (float): Norm of gradients before clipping.
"""
assert self.config.grad_clip is not None
if isinstance(self.module, FSDP):
grad_norm = self.module.clip_grad_norm_(self.config.grad_clip)
elif isinstance(self.module, FSDPModule):
grad_norm = fsdp2_clip_grad_norm_(self.module.parameters(), max_norm=self.config.grad_clip)
else:
grad_norm = torch.nn.utils.clip_grad_norm_(self.module.parameters(), max_norm=self.config.grad_clip)
# if grad_norm is not finite, skip the update
if not torch.isfinite(grad_norm):
print(f"WARN: grad_norm is not finite: {grad_norm}")
self.optimizer.zero_grad()
else:
self.optimizer.step()
return grad_norm
def lr_scheduler_step(self):
"""
Advance FSDP scheduler and return updated learning rate.
"""
self.lr_scheduler.step()
lr = self.lr_scheduler.get_last_lr()
return lr
def to(self, device: str, model: bool = True, optimizer: bool = True):
"""
Move FSDP model and/or optimizer to CPU or GPU with offload support.
"""
assert device in ("cuda", "cpu")
if device == "cuda":
if not self.config.model.fsdp_config.param_offload:
if model:
load_fsdp_model_to_gpu(self.model_module)
if optimizer and self.optimizer is not None:
load_fsdp_optimizer(self.optimizer, device)
gc.collect()
elif device == "cpu":
if not self.config.model.fsdp_config.param_offload:
if model:
offload_fsdp_model_to_cpu(self.model_module)
if optimizer and self.optimizer is not None:
offload_fsdp_optimizer(self.optimizer)
else:
raise ValueError(f"Invalid device type: {device}")
def save_checkpoint(self, local_path, hdfs_path=None, global_step=0, max_ckpt_to_keep=None):
"""
Save FSDP checkpoint, handling parameter offload as needed.
"""
if self._is_offload_param:
load_fsdp_model_to_gpu(self.module)
self.checkpoint_manager.save_checkpoint(
local_path=local_path, hdfs_path=hdfs_path, global_step=global_step, max_ckpt_to_keep=max_ckpt_to_keep
)
torch.distributed.barrier()
if self._is_offload_param:
offload_fsdp_model_to_cpu(self.module)
def load_checkpoint(self, local_path, hdfs_path=None, del_local_after_load=True):
"""
Load FSDP checkpoint, restoring parameters and optimizer state.
"""
import torch
if self._is_offload_param:
load_fsdp_model_to_gpu(self.module)
self.checkpoint_manager.load_checkpoint(
local_path=local_path, hdfs_path=hdfs_path, del_local_after_load=del_local_after_load
)
torch.distributed.barrier()
if self._is_offload_param:
offload_fsdp_model_to_cpu(self.module)
if self._is_offload_optimizer:
offload_fsdp_optimizer(self.optimizer)
class EngineEvalModeCtx:
def __init__(self, engine):
self.engine = engine
def __enter__(self):
self.engine.mode = "eval"
if self.engine._is_offload_param:
load_fsdp_model_to_gpu(self.engine.module)
self.engine.ulysses_sharding_manager.__enter__()
self.engine.module.eval()
def __exit__(self, exc_type, exc_value, traceback):
self.engine.ulysses_sharding_manager.__exit__(exc_type, exc_value, traceback)
if self.engine._is_offload_param:
offload_fsdp_model_to_cpu(self.engine.module)
self.engine.mode = None
class EngineTrainModeCtx:
def __init__(self, engine):
self.engine = engine
def __enter__(self):
self.engine.mode = "train"
if self.engine._is_offload_param:
load_fsdp_model_to_gpu(self.engine.module)
if self.engine._is_offload_optimizer:
load_fsdp_optimizer(optimizer=self.engine.optimizer, device_id=get_torch_device().current_device())
self.engine.ulysses_sharding_manager.__enter__()
self.engine.module.train()
def __exit__(self, exc_type, exc_value, traceback):
self.engine.ulysses_sharding_manager.__exit__(exc_type, exc_value, traceback)
if self.engine._is_offload_param:
offload_fsdp_model_to_cpu(self.engine.module)
if self.engine._is_offload_optimizer:
offload_fsdp_optimizer(optimizer=self.optimizer)
self.engine.mode = None

---

@@ -0,0 +1,61 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from torch.distributed.device_mesh import init_device_mesh
from verl.utils.device import get_device_name
def create_device_mesh(world_size, fsdp_size):
"""
Create a device mesh for distributed training based on the world size and FSDP size.
Args:
world_size (int): Total number of processes in the distributed training setup.
fsdp_size (int): Size of the Fully Sharded Data Parallel (FSDP) group.
Returns:
torch.distributed.device_mesh.DeviceMesh: The initialized device mesh.
"""
device_name = get_device_name()
if fsdp_size < 0 or fsdp_size >= world_size:
device_mesh = init_device_mesh(device_name, mesh_shape=(world_size,), mesh_dim_names=["fsdp"])
else:
device_mesh = init_device_mesh(
device_name, mesh_shape=(world_size // fsdp_size, fsdp_size), mesh_dim_names=["ddp", "fsdp"]
)
return device_mesh
def get_sharding_strategy(device_mesh):
"""
Determine the appropriate sharding strategy based on the number of dimensions of the device mesh.
Args:
device_mesh (torch.distributed.device_mesh.DeviceMesh): The device mesh used for distributed training.
Returns:
torch.distributed.fsdp.ShardingStrategy: The sharding strategy to be used with FSDP.
Raises:
NotImplementedError: If the number of dimensions of the device mesh is neither 1 nor 2.
"""
from torch.distributed.fsdp import ShardingStrategy
if device_mesh.ndim == 1:
sharding_strategy = ShardingStrategy.FULL_SHARD
elif device_mesh.ndim == 2:
sharding_strategy = ShardingStrategy.HYBRID_SHARD
else:
raise NotImplementedError(f"Get device mesh ndim={device_mesh.ndim}, but only support 1 or 2")
return sharding_strategy

---

@@ -0,0 +1,13 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

---

@@ -0,0 +1,166 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable
import torch
from verl import DataProto
from ..base import BaseEngine, EngineRegistry
@EngineRegistry.register("megatron")
class MegatronEngine(BaseEngine):
def __init__(self, config):
raise NotImplementedError
def init_model(self):
raise NotImplementedError
def train_mode(self):
"""
Context manager entry for switching the engine and model into training mode.
Usage:
with engine.train_mode():
# runs in training mode
"""
raise NotImplementedError
def eval_mode(self):
"""
Context manager entry for switching the engine and model into evaluation mode.
Usage:
with engine.eval_mode():
# runs in evaluation mode
"""
raise NotImplementedError
def infer_batch(
self,
data: DataProto,
post_fn: Callable[[DataProto, torch.Tensor], tuple[torch.Tensor, dict[str, torch.Tensor]]],
) -> dict[str, torch.Tensor]:
"""
Perform inference on a mini batch of data.
Args:
data: The input data for inference, typically containing tensors and metadata.
post_fn: A post-processing function that takes a micro-batch and predictions as input,
and returns a tuple containing processed predictions and a dictionary of outputs.
Returns:
dict[str, torch.Tensor]: A dictionary containing the predictions for the entire batch.
"""
raise NotImplementedError
def train_batch(
self,
data: DataProto,
loss_fn: Callable[[DataProto, torch.Tensor], tuple[torch.Tensor, dict[str, torch.Tensor]]],
) -> dict[str, torch.Tensor]:
"""
Perform a training step on a mini-batch of data.
Args:
data (DataProto): The input data for training, typically containing tensors and metadata.
loss_fn (Callable): A function that computes the loss and metrics given a micro-batch and predictions.
Returns:
dict[str, torch.Tensor]: A dictionary containing the aggregated training metrics for the mini-batch.
"""
raise NotImplementedError
def optimizer_zero_grad(self):
"""
Zero out gradients of all parameters before starting a new backward pass.
"""
raise NotImplementedError
def optimizer_step(self):
"""
Perform an optimization step to update model parameters based on accumulated gradients.
Returns:
grad_norm (float): The norm of the gradients before clipping or update.
"""
raise NotImplementedError
def lr_scheduler_step(self):
"""
Advance the learning rate scheduler by one step.
Returns:
current_lr (float or list[float]): Updated learning rate(s).
"""
raise NotImplementedError
def shard_data(self, data):
"""
Shard or partition data for distributed training or parallel execution.
Args:
data: Data structure to be sharded across devices/workers.
Returns:
Sharded data in the same format as input.
"""
raise NotImplementedError
def unshard_data(self, data):
"""
Reconstruct or gather sharded data back to a unified format.
Args:
data: Sharded data structure to reconstruct.
Returns:
Unsharded, combined data.
"""
raise NotImplementedError
def to(self, device: str, model: bool = True, optimizer: bool = True):
"""
Move model parameters, optimizer states, or both to the specified device.
Args:
device: Target device identifier.
model: If True, move the model.
optimizer: If True, move the optimizer states.
"""
raise NotImplementedError
def save_checkpoint(self, local_path, hdfs_path=None, global_step=0, max_ckpt_to_keep=None):
"""
Save model, optimizer, and scheduler states to a checkpoint.
Args:
local_path: Local filesystem path to save checkpoint.
hdfs_path: Optional HDFS path to copy checkpoint.
global_step: Integer training step number for naming.
max_ckpt_to_keep: Maximum number of recent checkpoints to retain.
"""
raise NotImplementedError
def load_checkpoint(self, local_path, hdfs_path=None, del_local_after_load=True):
"""
Load model, optimizer, and scheduler states from a checkpoint.
Args:
local_path: Local filesystem path of the checkpoint.
hdfs_path: Optional HDFS path where checkpoint is stored.
del_local_after_load: Whether to delete local copy after loading.
"""
raise NotImplementedError

---

@@ -0,0 +1,17 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .critic import CriticWorker
__all__ = ["CriticWorker"]

---

@@ -0,0 +1,51 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from verl import DataProto
from verl.single_controller.base import Worker
from verl.single_controller.base.decorator import Dispatch, register
class ActorWorker(Worker):
"""
This worker can be instantiated as a standalone actor or a standalone rollout or a standalone reference policy
or a hybrid engine based on the config.rollout
"""
def __init__(self, config):
raise NotImplementedError
@register(dispatch_mode=Dispatch.ONE_TO_ALL)
def init_model(self):
raise NotImplementedError
@register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
def update_actor(self, data: DataProto):
raise NotImplementedError
@register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
def compute_log_prob(self, data: DataProto):
raise NotImplementedError
@register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
def compute_ref_log_prob(self, data: DataProto):
raise NotImplementedError
@register(dispatch_mode=Dispatch.ONE_TO_ALL)
def save_checkpoint(self, local_path, hdfs_path=None, global_step=0, max_ckpt_to_keep=None):
raise NotImplementedError
@register(dispatch_mode=Dispatch.ONE_TO_ALL)
def load_checkpoint(self, local_path, hdfs_path=None, del_local_after_load=False):
raise NotImplementedError

---

@@ -0,0 +1,183 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
The main entry point to run the PPO algorithm
"""
import logging
import os
import torch
from codetiming import Timer
from verl import DataProto
from verl.single_controller.base import Worker
from verl.single_controller.base.decorator import Dispatch, register
from verl.trainer.ppo import core_algos
from verl.utils.config import omega_conf_to_dataclass
from verl.utils.device import (
get_device_id,
get_nccl_backend,
)
from verl.utils.profiler import DistProfiler, DistProfilerExtension
from verl.utils.py_functional import append_to_dict
from verl.utils.torch_functional import masked_mean
from verl.workers.engine import EngineRegistry
logger = logging.getLogger(__file__)
logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))
class CriticWorker(Worker, DistProfilerExtension):
def __init__(self, config):
Worker.__init__(self)
DistProfilerExtension.__init__(
self, DistProfiler(rank=self.rank, config=omega_conf_to_dataclass(config.get("profiler")))
)
import torch.distributed
if not torch.distributed.is_initialized():
torch.distributed.init_process_group(backend=get_nccl_backend())
self.config = config
self.engine = EngineRegistry.new(self.config.strategy, self.config)
@register(dispatch_mode=Dispatch.ONE_TO_ALL)
def init_model(self):
self.engine.init_model()
def _post_fn_values(self, micro_batch, preds):
response_length = micro_batch["responses"].size(-1)
values = preds[:, -response_length - 1 : -1]
use_remove_padding = self.config.model.get("use_remove_padding", False)
if not use_remove_padding:
values = values.squeeze(-1)
return values, {"values": values.clone().detach()}
@register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
@DistProfiler.annotate(color="cyan")
def compute_values(self, data: DataProto):
# Support all hardwares
data = data.to(get_device_id())
micro_batch_size = self.config.forward_micro_batch_size_per_gpu
data.meta_info["micro_batch_size"] = micro_batch_size
data.meta_info["max_token_len"] = self.config.forward_max_token_len_per_gpu
data.meta_info["use_dynamic_bsz"] = self.config.use_dynamic_bsz
with self.engine.eval_mode():
data = self.engine.shard_data(data=data)
output = self.engine.infer_batch(data, post_fn=self._post_fn_values)
response_mask = data.batch["response_mask"]
values = output["values"] * response_mask # Only action tokens have values
output = DataProto.from_dict(tensors={"values": values})
output = self.engine.unshard_data(data=output)
output = output.to("cpu")
return output
def loss_fn(
self, batch: DataProto, vpreds: dict[str, torch.Tensor]
) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:
old_values = batch["values"]
returns = batch["returns"]
response_mask = batch["response_mask"]
micro_batch_metrics = {}
values, _ = self._post_fn_values(batch, vpreds)
vf_loss, vf_clipfrac = core_algos.compute_value_loss(
vpreds=values,
values=old_values,
returns=returns,
response_mask=response_mask,
cliprange_value=self.config.cliprange_value,
loss_agg_mode=self.config.loss_agg_mode,
)
if self.config.use_dynamic_bsz:
# relative to the dynamic bsz
loss = vf_loss * (len(batch) / self.config.ppo_mini_batch_size)
else:
gradient_accumulation = self.config.ppo_mini_batch_size // self.config.ppo_micro_batch_size_per_gpu
loss = vf_loss / gradient_accumulation
micro_batch_metrics = {
"critic/vf_loss": vf_loss.detach().item(),
"critic/vf_clipfrac": vf_clipfrac.detach().item(),
"critic/vpred_mean": masked_mean(values, response_mask).detach().item(),
}
return loss, micro_batch_metrics
@register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
@DistProfiler.annotate(color="pink")
def update_critic(self, data: DataProto):
metrics = {}
# Support all hardwares
data = data.to(get_device_id())
# perform forward computation
with self.engine.train_mode():
data = self.engine.shard_data(data=data)
with Timer(name="update_critic", logger=None) as timer:
select_keys = [
"input_ids",
"responses",
"response_mask",
"attention_mask",
"position_ids",
"values",
"returns",
]
batch = data.select(batch_keys=select_keys).batch
has_multi_modal_inputs = "multi_modal_inputs" in data.non_tensor_batch.keys()
# Split to make minibatch iterator for updating the actor
# See PPO paper for details. https://arxiv.org/abs/1707.06347
if has_multi_modal_inputs:
num_mini_batches = data.batch.batch_size[0] // self.config.ppo_mini_batch_size
non_tensor_select_keys = ["multi_modal_inputs"]
dataloader = data.select(select_keys, non_tensor_select_keys).chunk(num_mini_batches)
else:
dataloader = batch.split(self.config.ppo_mini_batch_size)
for epoch in range(self.config.ppo_epochs):
for batch_idx, mini_batch in enumerate(dataloader):
self.engine.optimizer_zero_grad()
mini_batch_metrics = self.engine.train_batch(mini_batch, self.loss_fn)
grad_norm = self.engine.optimizer_step()
mini_batch_metrics["critic/grad_norm"] = grad_norm.detach().item()
append_to_dict(metrics, mini_batch_metrics)
self.engine.optimizer_zero_grad()
delta_time = timer.last
# TODO: should not access engine's flops_counter
global_num_tokens = data.meta_info["global_token_num"]
estimated_flops, promised_flops = self.engine.flops_counter.estimate_flops(global_num_tokens, delta_time)
metrics["perf/mfu/critic"] = estimated_flops * self.config.ppo_epochs / promised_flops / self.world_size
metrics["critic/lr"] = self.engine.lr_scheduler_step()[0]
output = DataProto(batch=None, meta_info={"metrics": metrics})
output = self.engine.unshard_data(data=output)
output = output.to("cpu")
return output
@register(dispatch_mode=Dispatch.ONE_TO_ALL)
def save_checkpoint(self, local_path, hdfs_path=None, global_step=0, max_ckpt_to_keep=None):
self.engine.save_checkpoint(local_path, hdfs_path, global_step, max_ckpt_to_keep)
@register(dispatch_mode=Dispatch.ONE_TO_ALL)
def load_checkpoint(self, local_path, hdfs_path=None, del_local_after_load=True):
self.engine.load_checkpoint(local_path, hdfs_path, del_local_after_load)