Files
verl/docs/workers/fsdp_workers.rst
2024-10-31 14:29:44 +08:00

142 lines
4.1 KiB
ReStructuredText
Raw Blame History

This file contains invisible Unicode characters

This file contains invisible Unicode characters that are indistinguishable to humans but may be processed differently by a computer. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

PyTorch FSDP Backend
============
We support PyTorch FSDP Backend by implementing various workers for
actor, critic, reference, rollout and reward models. We also implement
the ``FSDPVLLMShardingManager`` that reshard weight between FSDP and
vLLM in `fsdp_vllm.py <https://github.com/volcengine/verl/blob/main/verl/trainer/ppo/hybrid_engine/fsdp_vllm.py>`_.
**Pros**
- Readily support various models.
- Users only need to implement the corresponding
``dtensor_weight_loader`` for weight synchronization between FSDP
and vLLM. While for ``hf_weight_loader``, users can directly apply
any models supported both in HF and vLLM without any code change.
- Easy to organize the forward and backward computation for each model.
**Cons**
- Poor scalability when it comes to large-scale models (e.g. Llama 70B
and 405B)
- The resharding overhead between actor and rollout could be larger than
Megatron-LM backend.
Due to the simplicity, we recommend using FSDP backend for algorithm
research and prototyping.
FSDP Workers
------------
ActorRolloutRefWorker
^^^^^^^^^^^^^^^^^^^^^
Actor/Rollout HybridEngine
''''''''''''''''''''''''''
1. HybridEngine, Actor and Rollout initialization API.
.. code:: python
@register(dispatch_mode=Dispatch.ONE_TO_ALL)
def init_model(self):
``ONE_TO_ALL``: when calling the ``init_model`` function from the driver
process, each worker (on a GPU) will execute the following model
initialization process.
The initialization details of HybridEngine, Actor and Rollout are
highlighted below:
1. ``DataParallelPPOActor`` implements the simple PPO computation logics
when the model is built with FSDP, including compute log prob, model
update.
2. ``vLLMRollout`` support generation with vLLM. We modify the vLLM
Engine and make it executed under SPMD to fit into our
``WorkerGroup`` design.
3. ``FSDPVLLMShardingManager`` a context manager to perform actual
resharding between actor and rollout.
See `source code <https://github.com/volcengine/verl/blob/main/verl/trainer/ppo/workers/fsdp_workers.py#L42>`_. for more information.
1. Generate sequence and recompute log prob
.. code:: python
@register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
def generate_sequences(self, prompts: DataProto):
- ``Dispatch.DP_COMPUTE_PROTO``: The data will be dispatched and
collected along the DP dimension
- In this function, the rollout model will perform auto-regressive
generation and the actor model will recompute the old log prob for the
generetad response.
3. Update actor model
.. code:: python
@register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
def update_actor(self, data: DataProto):
- Update the actor model weight using PPO & entropy loss.
ReferenceModel
''''''''''''''
1. Reference model initialization
The reference model is initialized using the same function as the actor
model without initializing the HybridEngine and Optimizer. Then the
actor model is also wrapped by the ``DataParallelPPOActor``.
2. Compute reference log prob
.. code:: python
@register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
def compute_ref_log_prob(self, data: DataProto):
- In this function, the reference model will call the compute log prob
function in ``DataParallelPPOActor`` to compute the reference log
prob.
CriticWorker and RewardWorker
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
1. Model initialization
Quite similar to reference model. The CriticWorker will perform
additional initialization for the Optimizer.
2. Compute Values for CriticWorker
.. code:: python
@register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
def compute_values(self, data: DataProto):
3. Update Critic
.. code:: python
@register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
def update_critic(self, data: DataProto):
4. Compute Reward
.. code:: python
@register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
def compute_rm_score(self, data: DataProto):
HybridShard
------------
We didnt support FSDP `HybridShard`. To support this, we may need to
construct a 2D device mesh and test the corresponding
``dtensor_weight_loader`` and ``hf_weight_loader`` for each model.