verl/examples/ppo_trainer/run_deepseek7b_llm.sh
Ziheng Jiang 9d7cba4e12 [trainer] refactor: Training Engine Interface and Development Plan (#1977)
# [Refactor] Training Engine Interface and Development Plan

## Motivation  
See the original RFC for background:
https://github.com/volcengine/verl/issues/1371

Modernizing our training loop requires that we:

- **Decouple** training-backend implementation from algorithm code so
each can evolve independently.
- **Unify** on a single, well-defined `Engine` interface across
FSDP, Megatron, and other backends.
- **Enable** unit testing of each backend implementation in isolation.
- **Guarantee** that algorithm “roles” (Critic, Actor, Rollout, Ref) remain
completely engine-agnostic.

---

## Current Implementation  

This PR:
- Introduces an abstract `BaseEngine` class that defines a unified
training-engine interface.
- Implements `FSDPEngine`, a concrete `BaseEngine` using PyTorch
FullyShardedDataParallel.
- Provides a `CriticWorker` based on `FSDPEngine` that plugs seamlessly
into existing PPO training code without any changes.


### Classic Training Loop with the New Interface

```python
# 1. Build and initialize engine
engine = FSDPEngine(config)
engine.init_model()
engine.set_loss_fn(loss_fn)
ctx = {}  # context dict threaded through preprocess_fn, postprocess_fn, and loss_fn

# 2. Training loop
for epoch in range(config.num_epochs):
    for batch in train_loader:
        # a) zero gradients
        engine.optimizer_zero_grad()

        # b) forward + backward
        with engine.train_mode():
            preds, loss, ctx = engine.forward_backward_step(
                batch,
                ctx,
                forward_only=False,
                preprocess_fn=preprocess_fn,
                postprocess_fn=postprocess_fn
            )

        # c) update and schedule
        grad_norm = engine.optimizer_step()
        current_lr = engine.lr_scheduler_step()

# 3. Evaluation
with engine.eval_mode():
    for micro_batch in data:
        preds, ctx = engine.forward_backward_step(
            micro_batch,
            ctx,
            forward_only=True,
            preprocess_fn=preprocess_fn,
            postprocess_fn=postprocess_fn
        )
```

### Detailed BaseEngine Interface
We now introduce an abstract base class, `BaseEngine`, which defines our
unified training-engine interface.

**Key enhancements over the original RFC:**
- **`train_mode()` / `eval_mode()`**  
  Context managers to control parameter and activation load/offload at the
  start and end of each loop.
- **`shard_data()` / `unshard_data()`**  
  APIs for partitioning and gathering data across devices or workers.
- **`preprocess_fn` / `postprocess_fn` in `forward_backward_step()`**  
  Hooks to apply custom transformations before and after each micro-batch
  pass.
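
As a rough illustration of the `shard_data()` / `unshard_data()` pair, the sketch below partitions a list into contiguous chunks and gathers them back. It is not the FSDP implementation: the free functions, the explicit `num_shards` argument, and the list-of-shards return value are assumptions made so the round trip can run in a single process.

```python
def shard_data(data, num_shards):
    """Split a list into `num_shards` contiguous chunks of near-equal size.

    Hypothetical helper: a real engine would derive the shard count from its
    device mesh and return only the local shard on each rank.
    """
    if num_shards <= 0:
        raise ValueError("num_shards must be positive")
    base, extra = divmod(len(data), num_shards)
    shards, start = [], 0
    for rank in range(num_shards):
        # The first `extra` ranks take one additional item each.
        size = base + (1 if rank < extra else 0)
        shards.append(data[start:start + size])
        start += size
    return shards


def unshard_data(shards):
    """Concatenate shards in rank order to rebuild the original list."""
    return [item for shard in shards for item in shard]


samples = list(range(10))
shards = shard_data(samples, num_shards=4)
restored = unshard_data(shards)
```

Contiguous chunks keep the original ordering, so gathering is a plain concatenation in rank order.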

Below are the detailed signatures for each core method.

```python

class BaseEngine(object):
    """
    Abstract base class defining the interface for model training engines.

    Engine implementations must subclass BaseEngine and provide concrete behavior for all methods.
    """
    def __init__(self, config):
        """
        Initialize the BaseEngine.

        Args:
            config: Configuration object containing parameters for engine setup.
        """
        raise NotImplementedError

    def init_model(self):
        """
        Instantiate or load the model, optimizer, and learning rate scheduler.

        Should prepare all components necessary for training or evaluation.
        """
        raise NotImplementedError

    def train_mode(self):
        """
        Context manager entry for switching the engine and model into training mode.

        Usage:
            with engine.train_mode():
                # runs in training mode
        """
        raise NotImplementedError

    def eval_mode(self):        
        """
        Context manager entry for switching the engine and model into evaluation mode.

        Usage:
            with engine.eval_mode():
                # runs in evaluation mode
        """
        raise NotImplementedError

    def forward_backward_step(self, 
                              batch, 
                              ctx=None, 
                              forward_only=False, 
                              preprocess_fn=None, 
                              postprocess_fn=None):
        """
        Execute a forward pass (and optional backward pass) over a batch of data.

        Args:
            batch: Raw batch data (e.g., tensors or mappings) to process.
            ctx: Optional context dict passed to preprocess/postprocess functions.
            forward_only: If True, skip gradient computation and backward pass.
            preprocess_fn: Function(batch, ctx) -> (inputs, ctx), applied before model call.
            postprocess_fn: Function(outputs, ctx) -> (predictions, ctx), applied after model call.

        Returns:
            If forward_only:
                (predictions, ctx)
            Else:
                (predictions, loss, ctx)
        """
        raise NotImplementedError

    def optimizer_zero_grad(self):
        """
        Zero out gradients of all parameters before starting a new backward pass.
        """
        raise NotImplementedError

    def optimizer_step(self):
        """
        Perform an optimization step to update model parameters based on accumulated gradients.

        Returns:
            grad_norm (float): The norm of the gradients before clipping or update.
        """
        raise NotImplementedError

    def lr_scheduler_step(self):
        """
        Advance the learning rate scheduler by one step.

        Returns:
            current_lr (float or list[float]): Updated learning rate(s).
        """
        raise NotImplementedError

    def shard_data(self, data):
        """
        Shard or partition data for distributed training or parallel execution.

        Args:
            data: Data structure to be sharded across devices/workers.

        Returns:
            Sharded data in the same format as input.
        """
        raise NotImplementedError

    def unshard_data(self, data):
        """
        Reconstruct or gather sharded data back to a unified format.

        Args:
            data: Sharded data structure to reconstruct.

        Returns:
            Unsharded, combined data.
        """
        raise NotImplementedError
        

    def set_loss_fn(self, loss_fn):
        """
        Set the loss function to be used during training.

        Args:
            loss_fn: Callable(data, predictions, ctx) -> (loss_tensor, new_ctx)
        """
        raise NotImplementedError

    def to(self, device: str, model: bool = True, optimizer: bool = True):
        """
        Move model parameters, optimizer states, or both to the specified device.

        Args:
            device: Target device identifier (e.g., "cuda" or "cpu").
            model: If True, move the model.
            optimizer: If True, move the optimizer states.
        """
        raise NotImplementedError


    def save_checkpoint(self, local_path, hdfs_path=None, global_step=0, max_ckpt_to_keep=None):
        """
        Save model, optimizer, and scheduler states to a checkpoint.

        Args:
            local_path: Local filesystem path to save checkpoint.
            hdfs_path: Optional HDFS path to copy checkpoint.
            global_step: Integer training step number for naming.
            max_ckpt_to_keep: Maximum number of recent checkpoints to retain.
        """
        raise NotImplementedError


    def load_checkpoint(self, local_path, hdfs_path=None, del_local_after_load=True):
        """
        Load model, optimizer, and scheduler states from a checkpoint.

        Args:
            local_path: Local filesystem path of the checkpoint.
            hdfs_path: Optional HDFS path where checkpoint is stored.
            del_local_after_load: Whether to delete local copy after loading.
        """
        raise NotImplementedError
```
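
To make the hook and loss-function signatures concrete, here is a minimal sketch of a toy engine that mirrors a subset of the `BaseEngine` method names. `ToyEngine`, its single scalar parameter, the hard-coded mean-squared-error gradient, and the `num_samples` / `last_loss` context keys are all hypothetical; the sketch only shows how `ctx` could flow through `preprocess_fn`, `postprocess_fn`, and `loss_fn`.

```python
from contextlib import contextmanager


class ToyEngine:
    """Hypothetical engine fitting y = w * x; not part of verl."""

    def __init__(self, lr=0.1):
        self.lr = lr
        self.w = 0.0           # the only "model parameter"
        self.grad = 0.0        # accumulated gradient for w
        self.training = False
        self.loss_fn = None

    def set_loss_fn(self, loss_fn):
        self.loss_fn = loss_fn

    @contextmanager
    def train_mode(self):
        # A real engine might load offloaded parameters here.
        previous, self.training = self.training, True
        try:
            yield self
        finally:
            self.training = previous

    @contextmanager
    def eval_mode(self):
        previous, self.training = self.training, False
        try:
            yield self
        finally:
            self.training = previous

    def optimizer_zero_grad(self):
        self.grad = 0.0

    def forward_backward_step(self, batch, ctx=None, forward_only=False,
                              preprocess_fn=None, postprocess_fn=None):
        ctx = {} if ctx is None else ctx
        inputs, ctx = preprocess_fn(batch, ctx) if preprocess_fn else (batch, ctx)
        outputs = [self.w * x for x in inputs["x"]]
        preds, ctx = postprocess_fn(outputs, ctx) if postprocess_fn else (outputs, ctx)
        if forward_only:
            return preds, ctx
        loss, ctx = self.loss_fn(inputs, preds, ctx)
        # Analytic MSE gradient w.r.t. w; assumes an identity postprocess_fn.
        pairs = zip(preds, inputs["y"], inputs["x"])
        self.grad += sum(2 * (p - y) * x for p, y, x in pairs) / len(preds)
        return preds, loss, ctx

    def optimizer_step(self):
        grad_norm = abs(self.grad)
        self.w -= self.lr * self.grad
        return grad_norm


def preprocess_fn(batch, ctx):
    # Track how many samples have passed through the engine.
    ctx["num_samples"] = ctx.get("num_samples", 0) + len(batch["x"])
    return batch, ctx


def postprocess_fn(outputs, ctx):
    return outputs, ctx


def mse_loss(data, predictions, ctx):
    n = len(predictions)
    loss = sum((p - y) ** 2 for p, y in zip(predictions, data["y"])) / n
    ctx["last_loss"] = loss
    return loss, ctx


engine = ToyEngine(lr=0.1)
engine.set_loss_fn(mse_loss)
batch = {"x": [1.0, 2.0], "y": [2.0, 4.0]}
ctx = {}
for _ in range(50):
    engine.optimizer_zero_grad()
    with engine.train_mode():
        preds, loss, ctx = engine.forward_backward_step(
            batch, ctx, preprocess_fn=preprocess_fn, postprocess_fn=postprocess_fn
        )
    engine.optimizer_step()
```

Because the hooks return the context alongside their result, per-step bookkeeping stays outside the engine and no engine subclass needs to know which keys an algorithm stores.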

### FSDPEngine Implementation

A concrete `FSDPEngine` implements all methods using PyTorch
FullyShardedDataParallel, supporting all the features that the FSDP
DPCritic Worker supports:

- Multi-GPU/model sharding  
- Activation- and optimizer-offload  
- LoRA & sequence parallelism  
- Dynamic batch size and padding removal

### CriticWorker Implementation based on the FSDPEngine
- Unchanged public API
- Each role calls only `BaseEngine` methods (`init_model`,
`train_mode`/`eval_mode`, `forward_backward_step`, etc.)
- No modifications needed in existing algorithms (e.g., PPOTraining)
- New roles can be plugged in the same way as legacy ones
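
One way to picture a role that depends only on the engine interface is the sketch below. `EngineLike`, `ToyCriticRole`, and `RecordingEngine` are hypothetical names, not classes from this PR; the fake engine simply records calls, which also illustrates how a role could be unit-tested without any real backend.

```python
from contextlib import contextmanager
from typing import Any, Protocol


class EngineLike(Protocol):
    """Structural type covering the engine methods this toy role uses."""

    def train_mode(self) -> Any: ...
    def optimizer_zero_grad(self) -> None: ...
    def forward_backward_step(self, batch, ctx=None, forward_only=False,
                              preprocess_fn=None, postprocess_fn=None): ...
    def optimizer_step(self) -> float: ...
    def lr_scheduler_step(self) -> float: ...


class ToyCriticRole:
    """Hypothetical role: touches the engine only through interface methods."""

    def __init__(self, engine: EngineLike):
        self.engine = engine

    def update(self, batches):
        metrics, ctx = [], {}
        for batch in batches:
            self.engine.optimizer_zero_grad()
            with self.engine.train_mode():
                _, loss, ctx = self.engine.forward_backward_step(batch, ctx)
            grad_norm = self.engine.optimizer_step()
            lr = self.engine.lr_scheduler_step()
            metrics.append({"loss": loss, "grad_norm": grad_norm, "lr": lr})
        return metrics


class RecordingEngine:
    """Fake backend that logs each call instead of training anything."""

    def __init__(self):
        self.calls = []

    @contextmanager
    def train_mode(self):
        self.calls.append("train_mode:enter")
        try:
            yield
        finally:
            self.calls.append("train_mode:exit")

    def optimizer_zero_grad(self):
        self.calls.append("optimizer_zero_grad")

    def forward_backward_step(self, batch, ctx=None, forward_only=False,
                              preprocess_fn=None, postprocess_fn=None):
        self.calls.append("forward_backward_step")
        return batch, 0.0, ctx

    def optimizer_step(self):
        self.calls.append("optimizer_step")
        return 1.0

    def lr_scheduler_step(self):
        self.calls.append("lr_scheduler_step")
        return 1e-5


fake = RecordingEngine()
metrics = ToyCriticRole(fake).update([{"x": 1}])
```

Swapping `RecordingEngine` for any other object with the same methods requires no change to the role, which is the property the bullet points above describe.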

## Development Plan
We’ll roll this out in three gated phases, controlled by a feature flag
(`use_legacy_worker_impl`).

### Phase 1: Engine Development
> Flag: `use_legacy_worker_impl = True` (default)  
> New interface under active development

- Refactor Critic, Actor, Rollout, Ref to use only BaseEngine APIs
- Design a hierarchical, immutable config system for engine/backends
- Ensure PPO training curves and final accuracy match the legacy
implementation

### Phase 2: Migration
> Flag: `use_legacy_worker_impl = False` (default); the legacy path logs a
> deprecation warning  
> All new code targets the new interface; 2–3 months of
> integration/stress testing

- Enforce new interface for all feature work
- Gather benchmarks, bug reports, and performance data
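
The flag behavior described for Phases 1 and 2 could be gated roughly as follows. This is a sketch under stated assumptions, not the trainer's actual logic: `resolve_use_legacy`, `select_worker_impl`, the accepted string spellings, and the reading of `auto` as "use the current phase's default" are all hypothetical.

```python
import warnings

# Default value of the flag in each phase, as described in the plan;
# Phase 3 removes the flag, so it has no entry here.
PHASE_DEFAULT_USE_LEGACY = {1: True, 2: False}


def resolve_use_legacy(flag, phase):
    """Map a raw flag value (bool, string, or None) to a boolean."""
    default = PHASE_DEFAULT_USE_LEGACY[phase]
    if flag is None:
        return default
    if isinstance(flag, bool):
        return flag
    text = str(flag).strip().lower()
    if text == "auto":
        # Assumption: "auto" defers to the phase default.
        return default
    if text in ("true", "1"):
        return True
    if text in ("false", "0"):
        return False
    raise ValueError(f"unrecognized use_legacy_worker_impl value: {flag!r}")


def select_worker_impl(flag, phase):
    """Return "legacy" or "engine", warning when legacy is used in Phase 2."""
    use_legacy = resolve_use_legacy(flag, phase)
    if use_legacy and phase == 2:
        warnings.warn(
            "legacy worker implementation is deprecated",
            DeprecationWarning,
            stacklevel=2,
        )
    return "legacy" if use_legacy else "engine"
```

Keeping the resolution in one function means the default can flip between phases without touching any call site.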

### Phase 3: Cleanup
> After Phase 2 validation:
- Remove legacy worker code and flags
- Finalize documentation, update changelogs, close deprecation notices

Please review this refactor and share any feedback or concerns!
Contributions are welcome.
2025-07-17 22:05:21 -07:00


set -x
python3 -m verl.trainer.main_ppo \
algorithm.adv_estimator=gae \
data.train_files=$HOME/data/gsm8k/train.parquet \
data.val_files=$HOME/data/gsm8k/test.parquet \
data.train_batch_size=1024 \
data.max_prompt_length=512 \
data.max_response_length=512 \
data.filter_overlong_prompts=True \
data.truncation='error' \
actor_rollout_ref.model.path=deepseek-ai/deepseek-llm-7b-chat \
actor_rollout_ref.actor.optim.lr=1e-6 \
actor_rollout_ref.model.use_remove_padding=True \
actor_rollout_ref.actor.ppo_mini_batch_size=256 \
actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=16 \
actor_rollout_ref.actor.fsdp_config.param_offload=False \
actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
actor_rollout_ref.actor.use_kl_loss=False \
actor_rollout_ref.model.enable_gradient_checkpointing=True \
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=32 \
actor_rollout_ref.rollout.tensor_model_parallel_size=4 \
actor_rollout_ref.rollout.name=vllm \
actor_rollout_ref.rollout.gpu_memory_utilization=0.4 \
critic.optim.lr=1e-5 \
critic.model.use_remove_padding=True \
critic.model.path=deepseek-ai/deepseek-llm-7b-chat \
critic.model.enable_gradient_checkpointing=True \
critic.ppo_micro_batch_size_per_gpu=32 \
critic.model.fsdp_config.param_offload=False \
critic.model.fsdp_config.optimizer_offload=False \
algorithm.use_kl_in_reward=False \
trainer.critic_warmup=0 \
trainer.logger='["console","wandb"]' \
trainer.project_name='verl_example_gsm8k' \
trainer.experiment_name='deepseek_llm_7b_function_rm' \
trainer.n_gpus_per_node=8 \
trainer.nnodes=1 \
trainer.save_freq=20 \
trainer.test_freq=1 \
trainer.use_legacy_worker_impl=auto \
trainer.total_epochs=15 "$@"