# [Refactor] Training Engine Interface and Development Plan

## Motivation

See the original RFC for background: https://github.com/volcengine/verl/issues/1371

Modernizing our training loop requires that we:

- **Decouple** the training-backend implementation from algorithm code so each can evolve independently
- **Unify** on a single, well-defined `Engine` interface across FSDP, Megatron, and other backends
- **Enable** unit testing of each backend implementation in isolation
- **Guarantee** that algorithm “roles” (Critic, Actor, Rollout, Ref) remain completely engine-agnostic

---

## Current Implementation

This PR:

- Introduces an abstract `BaseEngine` class that defines a unified training-engine interface.
- Implements `FSDPEngine`, a concrete `BaseEngine` using PyTorch FullyShardedDataParallel.
- Provides a `CriticWorker` based on `FSDPEngine` that plugs seamlessly into existing PPO training code without any changes.

### Classic Training Loop with the New Interface

```python
# 1. Build and initialize engine
engine = FSDPEngine(config)
engine.init_model()
engine.set_loss_fn(loss_fn)

# Shared context dict threaded through preprocess/loss/postprocess hooks
ctx = None

# 2. Training loop
for epoch in range(config.num_epochs):
    for batch in train_loader:
        # a) zero gradients
        engine.optimizer_zero_grad()

        # b) forward + backward
        with engine.train_mode():
            preds, loss, ctx = engine.forward_backward_step(
                batch,
                ctx,
                forward_only=False,
                preprocess_fn=preprocess_fn,
                postprocess_fn=postprocess_fn,
            )

        # c) update and schedule
        grad_norm = engine.optimizer_step()
        current_lr = engine.lr_scheduler_step()

# 3. Evaluation
with engine.eval_mode():
    for micro_batch in data:
        preds, ctx = engine.forward_backward_step(
            micro_batch,
            ctx,
            forward_only=True,
            preprocess_fn=preprocess_fn,
            postprocess_fn=postprocess_fn,
        )
```

### Detailed BaseEngine Interface

We introduce an abstract base class, `BaseEngine`, which defines our unified training-engine interface.

**Key enhancements over the original RFC:**

- **`train_mode()` / `eval_mode()`**: context managers that control parameter and activation load/offload at the start and end of each loop.
- **`shard_data()` / `unshard_data()`**: APIs for partitioning and gathering data across devices or workers.
- **`preprocess_fn` / `postprocess_fn` in `forward_backward_step()`**: hooks to apply custom transformations before and after each micro-batch pass (see the sketch after the interface below).

Below are the detailed signatures for each core method.

```python
class BaseEngine(object):
    """
    Abstract base class defining the interface for model training engines.

    Engine implementations must subclass BaseEngine and provide concrete
    behavior for all methods.
    """

    def __init__(self, config):
        """
        Initialize the BaseEngine.

        Args:
            config: Configuration object containing parameters for engine setup.
        """
        raise NotImplementedError

    def init_model(self):
        """
        Instantiate or load the model, optimizer, and learning rate scheduler.

        Should prepare all components necessary for training or evaluation.
        """
        raise NotImplementedError

    def train_mode(self):
        """
        Context manager entry for switching the engine and model into training mode.

        Usage:
            with engine.train_mode():
                # runs in training mode
        """
        raise NotImplementedError

    def eval_mode(self):
        """
        Context manager entry for switching the engine and model into evaluation mode.

        Usage:
            with engine.eval_mode():
                # runs in evaluation mode
        """
        raise NotImplementedError

    def forward_backward_step(self, batch, ctx=None, forward_only=False,
                              preprocess_fn=None, postprocess_fn=None):
        """
        Execute a forward pass (and optional backward pass) over a batch of data.

        Args:
            batch: Raw batch data (e.g., tensors or mappings) to process.
            ctx: Optional context dict passed to preprocess/postprocess functions.
            forward_only: If True, skip gradient computation and backward pass.
            preprocess_fn: Function(batch, ctx) -> (inputs, ctx), applied before model call.
            postprocess_fn: Function(outputs, ctx) -> (predictions, ctx), applied after model call.

        Returns:
            If forward_only: (predictions, ctx)
            Else: (predictions, loss, ctx)
        """
        raise NotImplementedError

    def optimizer_zero_grad(self):
        """
        Zero out gradients of all parameters before starting a new backward pass.
        """
        raise NotImplementedError

    def optimizer_step(self):
        """
        Perform an optimization step to update model parameters based on accumulated gradients.

        Returns:
            grad_norm (float): The norm of the gradients before clipping or update.
        """
        raise NotImplementedError

    def lr_scheduler_step(self):
        """
        Advance the learning rate scheduler by one step.

        Returns:
            current_lr (float or list[float]): Updated learning rate(s).
        """
        raise NotImplementedError

    def shard_data(self, data):
        """
        Shard or partition data for distributed training or parallel execution.

        Args:
            data: Data structure to be sharded across devices/workers.

        Returns:
            Sharded data in the same format as input.
        """
        raise NotImplementedError

    def unshard_data(self, data):
        """
        Reconstruct or gather sharded data back to a unified format.

        Args:
            data: Sharded data structure to reconstruct.

        Returns:
            Unsharded, combined data.
        """
        raise NotImplementedError

    def set_loss_fn(self, loss_fn):
        """
        Set the loss function to be used during training.

        Args:
            loss_fn: Callable(data, predictions, ctx) -> (loss_tensor, new_ctx)
        """
        raise NotImplementedError

    def to(self, device: str, model: bool = True, optimizer: bool = True):
        """
        Move model parameters, optimizer states, or both to the specified device.

        Args:
            device: Target device identifier (e.g., "cuda" or "cpu").
            model: If True, move the model.
            optimizer: If True, move the optimizer states.
        """
        raise NotImplementedError

    def save_checkpoint(self, local_path, hdfs_path=None, global_step=0, max_ckpt_to_keep=None):
        """
        Save model, optimizer, and scheduler states to a checkpoint.

        Args:
            local_path: Local filesystem path to save checkpoint.
            hdfs_path: Optional HDFS path to copy checkpoint.
            global_step: Integer training step number for naming.
            max_ckpt_to_keep: Maximum number of recent checkpoints to retain.
        """
        raise NotImplementedError

    def load_checkpoint(self, local_path, hdfs_path=None, del_local_after_load=True):
        """
        Load model, optimizer, and scheduler states from a checkpoint.

        Args:
            local_path: Local filesystem path of the checkpoint.
            hdfs_path: Optional HDFS path where checkpoint is stored.
            del_local_after_load: Whether to delete local copy after loading.
        """
        raise NotImplementedError
```
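As a concrete illustration of how `set_loss_fn`, `preprocess_fn`, and `postprocess_fn` compose, here is a minimal sketch for a critic-style value loss, wired into the `engine` from the loop above. The field names (`input_ids`, `attention_mask`, `returns`, `response_mask`), the `outputs.logits` access, and the masked-MSE loss are illustrative assumptions, not part of the interface contract:

```python
def preprocess_fn(batch, ctx):
    # Build the raw model inputs from the micro-batch; supervision targets stay in the batch.
    inputs = {"input_ids": batch["input_ids"], "attention_mask": batch["attention_mask"]}
    return inputs, ctx

def postprocess_fn(outputs, ctx):
    # Turn raw model outputs into per-token value predictions (assumes a value head in `logits`).
    values = outputs.logits.squeeze(-1)
    return values, ctx

def loss_fn(data, predictions, ctx):
    # Masked MSE between predicted values and target returns (illustrative only).
    mask = data["response_mask"].float()
    squared_error = (predictions - data["returns"]) ** 2
    loss = (squared_error * mask).sum() / mask.sum().clamp(min=1.0)
    return loss, ctx

engine.set_loss_fn(loss_fn)
```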
### FSDPEngine Implementation

A concrete `FSDPEngine` implements all of these methods using PyTorch FullyShardedDataParallel, supporting all the features that the existing FSDP `DPCriticWorker` supports:

- Multi-GPU / model sharding
- Activation and optimizer offload
- LoRA and sequence parallelism
- Dynamic batch size and remove-padding

### CriticWorker Implementation based on the FSDPEngine

- Unchanged public API
- Each role calls only `BaseEngine` methods (`init_model`, `train_mode`/`eval_mode`, `forward_backward_step`, etc.); see the simplified sketch below
- No modifications needed in existing algorithms (e.g., PPOTraining)
- New roles can be plugged in identically to legacy code
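To make the “engine-agnostic role” claim concrete, a heavily simplified sketch of what such a worker could look like when written purely against `BaseEngine`. The method names, the `split_into_micro_batches` helper, the config key, and the metric names are assumptions for illustration; the actual worker in this PR is more involved:

```python
class CriticWorker:
    """Sketch only: a role that talks exclusively to the BaseEngine interface."""

    def __init__(self, config, engine):
        self.config = config
        self.engine = engine  # any BaseEngine: FSDPEngine, a future MegatronEngine, ...
        self.engine.init_model()
        self.engine.set_loss_fn(loss_fn)

    def compute_values(self, data):
        data = self.engine.shard_data(data)
        with self.engine.eval_mode():
            values, _ = self.engine.forward_backward_step(
                data, forward_only=True,
                preprocess_fn=preprocess_fn, postprocess_fn=postprocess_fn,
            )
        return self.engine.unshard_data(values)

    def update_critic(self, data):
        metrics = {"critic/loss": [], "critic/grad_norm": []}
        data = self.engine.shard_data(data)
        with self.engine.train_mode():
            # split_into_micro_batches is a hypothetical helper for this sketch.
            for micro_batch in split_into_micro_batches(data, self.config.ppo_micro_batch_size_per_gpu):
                self.engine.optimizer_zero_grad()
                _, loss, _ = self.engine.forward_backward_step(
                    micro_batch, forward_only=False,
                    preprocess_fn=preprocess_fn, postprocess_fn=postprocess_fn,
                )
                grad_norm = self.engine.optimizer_step()
                metrics["critic/loss"].append(float(loss))
                metrics["critic/grad_norm"].append(grad_norm)
        self.engine.lr_scheduler_step()
        return metrics
```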
## Development Plan

We’ll roll this out in three gated phases, controlled by a feature flag (`use_legacy_worker_impl`); a small dispatch sketch follows the phase breakdown.

### Phase 1: Engine Development

> Flag: `use_legacy_worker_impl = True` (default)
> New interface under active development

- Refactor Critic, Actor, Rollout, and Ref to use only `BaseEngine` APIs
- Design a hierarchical, immutable config system for engines/backends
- Ensure PPO training curves and final accuracy match the legacy implementation

### Phase 2: Migration

> Flag: `use_legacy_worker_impl = False` (default); the legacy path logs a deprecation warning
> All new code targets the new interface; 2–3 months of integration/stress testing

- Enforce the new interface for all feature work
- Gather benchmarks, bug reports, and performance data

### Phase 3: Cleanup

> After Phase 2 validation:

- Remove legacy worker code and flags
- Finalize documentation, update changelogs, and close deprecation notices

Please review this refactor and share any feedback or concerns! Contributions are welcome.
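As a companion to the plan above, a minimal sketch of how the `use_legacy_worker_impl` gate could select between the legacy and engine-based critic paths during Phases 1–2. The names `build_critic_worker` and `LegacyCriticWorker` are illustrative assumptions, not the actual verl wiring:

```python
# Sketch only: flag-gated worker construction during the migration window.
def build_critic_worker(config):
    if config.trainer.use_legacy_worker_impl:
        # Phase 1 default: keep routing through the existing critic implementation.
        return LegacyCriticWorker(config)  # hypothetical stand-in for the current worker
    # New path: the role is written purely against the BaseEngine interface.
    return CriticWorker(config, engine=FSDPEngine(config))
```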
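---

The example script below launches a GSM8K PPO run through `verl.trainer.main_ppo`; note the `trainer.use_legacy_worker_impl=auto` override, which exposes the flag described in the plan above: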
```bash
set -x

python3 -m verl.trainer.main_ppo \
    algorithm.adv_estimator=gae \
    data.train_files=$HOME/data/gsm8k/train.parquet \
    data.val_files=$HOME/data/gsm8k/test.parquet \
    data.train_batch_size=1024 \
    data.max_prompt_length=512 \
    data.max_response_length=512 \
    data.filter_overlong_prompts=True \
    data.truncation='error' \
    actor_rollout_ref.model.path=deepseek-ai/deepseek-llm-7b-chat \
    actor_rollout_ref.actor.optim.lr=1e-6 \
    actor_rollout_ref.model.use_remove_padding=True \
    actor_rollout_ref.actor.ppo_mini_batch_size=256 \
    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=16 \
    actor_rollout_ref.actor.fsdp_config.param_offload=False \
    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
    actor_rollout_ref.actor.use_kl_loss=False \
    actor_rollout_ref.model.enable_gradient_checkpointing=True \
    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=32 \
    actor_rollout_ref.rollout.tensor_model_parallel_size=4 \
    actor_rollout_ref.rollout.name=vllm \
    actor_rollout_ref.rollout.gpu_memory_utilization=0.4 \
    critic.optim.lr=1e-5 \
    critic.model.use_remove_padding=True \
    critic.model.path=deepseek-ai/deepseek-llm-7b-chat \
    critic.model.enable_gradient_checkpointing=True \
    critic.ppo_micro_batch_size_per_gpu=32 \
    critic.model.fsdp_config.param_offload=False \
    critic.model.fsdp_config.optimizer_offload=False \
    algorithm.use_kl_in_reward=False \
    trainer.critic_warmup=0 \
    trainer.logger='["console","wandb"]' \
    trainer.project_name='verl_example_gsm8k' \
    trainer.experiment_name='deepseek_llm_7b_function_rm' \
    trainer.n_gpus_per_node=8 \
    trainer.nnodes=1 \
    trainer.save_freq=20 \
    trainer.test_freq=1 \
    trainer.use_legacy_worker_impl=auto \
    trainer.total_epochs=15 $@
```