verl/docs/advance/fsdp_extension.rst
2024-10-31 14:29:44 +08:00


Add models to FSDP backend
===========================
Model
--------------------------
In principle, our FSDP backend can support any HF model, and we can
synchronize the actor model weights with vLLM using `hf_weight_loader.py <https://github.com/volcengine/verl/blob/main/verl/third_party/vllm/vllm_v_0_5_4/hf_weight_loader.py>`_.
However, ``hf_weight_loader`` gathers the full state_dict of a
model during synchronization, which may cause OOM. We suggest using
``dtensor_weight_loader``, which gathers the full model parameters layer by
layer to reduce peak memory usage. We already support the dtensor weight
loader for the models below in `dtensor_weight_loader.py <https://github.com/volcengine/verl/blob/main/verl/third_party/vllm/vllm_v_0_5_4/dtensor_weight_loader.py>`_:

- ``GPT2LMHeadModel``
- ``LlamaForCausalLM``
- ``LLaMAForCausalLM``
- ``MistralForCausalLM``
- ``InternLMForCausalLM``
- ``AquilaModel``
- ``AquilaForCausalLM``
- ``Phi3ForCausalLM``
- ``GemmaForCausalLM``
- ``Gemma2ForCausalLM``
- ``GPTBigCodeForCausalLM``
- ``Starcoder2ForCausalLM``
- ``Qwen2ForCausalLM``
- ``DeepseekV2ForCausalLM``
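
The memory argument above can be illustrated with a rough back-of-the-envelope
sketch. It is a simplified accounting, not a measurement: the helper name
``peak_gather_bytes`` and the layer sizes are hypothetical, and it assumes that
each gathered layer is freed before the next one is gathered and ignores the
shards already resident on the device.

.. code-block:: python

    def peak_gather_bytes(layer_param_counts, bytes_per_param=2):
        """Return (full_gather_peak, layer_by_layer_peak) in bytes.

        full_gather_peak: every layer is materialized at once, as when a
        full state_dict is gathered.
        layer_by_layer_peak: only the largest single layer is materialized
        at any one time (assumes earlier layers were already freed).
        """
        sizes = [n * bytes_per_param for n in layer_param_counts]
        return sum(sizes), max(sizes)


    # Hypothetical model: 32 layers of 200M parameters each, stored in bf16.
    full, layered = peak_gather_bytes([200_000_000] * 32)
    print(f"full state_dict gather: {full / 2**30:.1f} GiB")
    print(f"layer-by-layer gather:  {layered / 2**30:.1f} GiB")

Under these assumptions the extra peak memory scales with the largest layer
rather than with the whole model, which is why the layer-wise loader is less
likely to run out of memory.
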

To implement ``dtensor_weight_loader`` for a model that is supported in
vLLM, follow the Gemma example below:

1. Copy
   ``load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]])`` from the vLLM model class
   to ``dtensor_weight_loaders.py``.
2. Change the arguments to
   ``(actor_weights: Dict, vllm_model: nn.Module)``.
3. Replace ``self`` with ``vllm_model``.
4. Add
   ``local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight)``
   before each ``param = params_dict[name]`` and pass
   ``local_loaded_weight`` to the weight-loading call that follows.
5. Register the implemented dtensor weight loader in ``__MODEL_DTENSOR_WEIGHT_LOADER_REGISTRY__``.

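
The registration in step 5 might look roughly like the sketch below. It is an
illustration only: ``LOADER_REGISTRY`` and ``get_loader`` are hypothetical
stand-ins, and it assumes the real ``__MODEL_DTENSOR_WEIGHT_LOADER_REGISTRY__``
behaves like a dictionary keyed by the model architecture name, which should
be checked against the source file.

.. code-block:: python

    from typing import Callable, Dict

    # Hypothetical stand-in for __MODEL_DTENSOR_WEIGHT_LOADER_REGISTRY__.
    # Assumption: architecture name -> loader function.
    LOADER_REGISTRY: Dict[str, Callable] = {}


    def gemma_dtensor_weight_loader(actor_weights, vllm_model):
        """Placeholder body; the real loader copies weights into vllm_model."""
        return vllm_model


    # Step 5: register the loader under the architecture name.
    LOADER_REGISTRY["GemmaForCausalLM"] = gemma_dtensor_weight_loader


    def get_loader(arch: str) -> Callable:
        """Look up the loader for an architecture, failing loudly if missing."""
        if arch not in LOADER_REGISTRY:
            raise ValueError(f"No dtensor weight loader registered for {arch}")
        return LOADER_REGISTRY[arch]

A lookup that fails loudly makes a missing registration visible at startup
instead of during weight synchronization.
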

.. code-block:: diff

    - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
    + def gemma_dtensor_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn.Module:
          stacked_params_mapping = [
              # (param_name, shard_name, shard_id)
              ("qkv_proj", "q_proj", "q"),
              ("qkv_proj", "k_proj", "k"),
              ("qkv_proj", "v_proj", "v"),
              ("gate_up_proj", "gate_proj", 0),
              ("gate_up_proj", "up_proj", 1),
          ]
    -     params_dict = dict(self.named_parameters())
    +     params_dict = dict(vllm_model.named_parameters())
          loaded_params = set()
    -     for name, loaded_weight in weights:
    +     for name, loaded_weight in actor_weights.items():
              for (param_name, shard_name, shard_id) in stacked_params_mapping:
                  if shard_name not in name:
                      continue
                  name = name.replace(shard_name, param_name)
                  # Skip loading extra bias for GPTQ models.
                  if name.endswith(".bias") and name not in params_dict:
                      continue
    +             local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight)
                  param = params_dict[name]
                  weight_loader = param.weight_loader
    -             weight_loader(param, loaded_weight, shard_id)
    +             weight_loader(param, local_loaded_weight.to(dtype=param.dtype), shard_id)
                  break
              else:
                  # lm_head is not used in vllm as it is tied with embed_token.
                  # To prevent errors, skip loading lm_head.weight.
                  if "lm_head.weight" in name:
                      continue
                  # Skip loading extra bias for GPTQ models.
                  if name.endswith(".bias") and name not in params_dict:
                      continue
    +             local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight)
                  param = params_dict[name]
                  weight_loader = getattr(param, "weight_loader",
                                          default_weight_loader)
    -             weight_loader(param, loaded_weight)
    +             weight_loader(param, local_loaded_weight.to(dtype=param.dtype))
              loaded_params.add(name)
          unloaded_params = params_dict.keys() - loaded_params
          if unloaded_params:
              raise RuntimeError(
                  "Some weights are not initialized from checkpoints: "
                  f"{unloaded_params}")
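
The name rewriting performed by ``stacked_params_mapping`` in the loader can
be tried in isolation. The sketch below reuses the mapping from the Gemma
example; the helper ``resolve_stacked_name`` and the parameter names passed to
it are hypothetical and exist only for illustration.

.. code-block:: python

    stacked_params_mapping = [
        # (param_name, shard_name, shard_id)
        ("qkv_proj", "q_proj", "q"),
        ("qkv_proj", "k_proj", "k"),
        ("qkv_proj", "v_proj", "v"),
        ("gate_up_proj", "gate_proj", 0),
        ("gate_up_proj", "up_proj", 1),
    ]


    def resolve_stacked_name(name):
        """Map an HF parameter name to (fused vLLM name, shard_id).

        Returns (name, None) when the parameter is not part of a fused
        projection, which corresponds to the ``else`` branch of the loader.
        """
        for param_name, shard_name, shard_id in stacked_params_mapping:
            if shard_name in name:
                return name.replace(shard_name, param_name), shard_id
        return name, None


    print(resolve_stacked_name("model.layers.0.self_attn.k_proj.weight"))
    print(resolve_stacked_name("model.layers.0.mlp.up_proj.weight"))
    print(resolve_stacked_name("model.embed_tokens.weight"))

The shard id tells the parameter's ``weight_loader`` which slice of the fused
tensor the incoming weight belongs to.
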