Add models to FSDP backend
===========================

Model
--------------------------

In principle, our FSDP backend can support any HF model, and we can
synchronize the actor model weights with vLLM using `hf_weight_loader.py <https://github.com/volcengine/verl/blob/main/verl/third_party/vllm/vllm_v_0_5_4/hf_weight_loader.py>`_.
However, ``hf_weight_loader`` gathers the full state_dict of the model
during synchronization, which may cause OOM. We therefore suggest using
``dtensor_weight_loader``, which gathers the full model parameters layer by
layer to reduce the peak memory usage; a conceptual sketch of this
layer-by-layer gather follows the model list below. We already support
dtensor weight loaders for the models below in
`dtensor_weight_loader.py <https://github.com/volcengine/verl/blob/main/verl/third_party/vllm/vllm_v_0_5_4/dtensor_weight_loader.py>`_:

- ``GPT2LMHeadModel``
- ``LlamaForCausalLM``
- ``LLaMAForCausalLM``
- ``MistralForCausalLM``
- ``InternLMForCausalLM``
- ``AquilaModel``
- ``AquilaForCausalLM``
- ``Phi3ForCausalLM``
- ``GemmaForCausalLM``
- ``Gemma2ForCausalLM``
- ``GPTBigCodeForCausalLM``
- ``Starcoder2ForCausalLM``
- ``Qwen2ForCausalLM``
- ``DeepseekV2ForCausalLM``

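The memory advantage of ``dtensor_weight_loader`` comes from materializing only
one parameter's full tensor at a time. The snippet below is a conceptual sketch
only, not verl's actual ``redistribute_dtensor`` implementation; it assumes the
actor weights arrive as PyTorch ``DTensor`` shards and that the parameter names
already match the vLLM model (the real loaders additionally remap HF names onto
vLLM's fused parameters, as the gemma diff later on this page shows).

.. code-block:: python

   # Conceptual sketch only, NOT verl's redistribute_dtensor.
   # Gather each FSDP-sharded parameter to a full tensor one layer at a
   # time, so peak memory stays around a single layer instead of the
   # whole state_dict.
   from typing import Dict

   import torch.nn as nn
   from torch.distributed._tensor import DTensor  # DTensor API, PyTorch >= 2.1 assumed


   def sync_weights_layer_by_layer(actor_weights: Dict, vllm_model: nn.Module):
       params_dict = dict(vllm_model.named_parameters())
       for name, loaded_weight in actor_weights.items():
           if isinstance(loaded_weight, DTensor):
               # All-gather only this one parameter across the FSDP shards.
               loaded_weight = loaded_weight.full_tensor()
           param = params_dict[name]
           weight_loader = getattr(param, "weight_loader", None)
           if weight_loader is not None:
               weight_loader(param, loaded_weight.to(dtype=param.dtype))
           else:
               param.data.copy_(loaded_weight.to(dtype=param.dtype))
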
To implement a ``dtensor_weight_loader`` for a model that is already supported in
vLLM, follow the guide for the gemma model below:

1. Copy the
   ``load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]])`` method from the vLLM model class
   into ``dtensor_weight_loaders.py``.
2. Change the arguments to
   ``(actor_weights: Dict, vllm_model: nn.Module)``.
3. Replace ``self`` with ``vllm_model``.
4. Add
   ``local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight)``
   before each ``param = params_dict[name]`` and perform the subsequent
   weight loading with ``local_loaded_weight``.
5. Register the implemented dtensor weight loader in
   ``__MODEL_DTENSOR_WEIGHT_LOADER_REGISTRY__`` (a registration sketch is
   shown after the diff below).

.. code-block:: diff

   - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
   + def gemma_dtensor_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn.Module:
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             ("qkv_proj", "q_proj", "q"),
             ("qkv_proj", "k_proj", "k"),
             ("qkv_proj", "v_proj", "v"),
             ("gate_up_proj", "gate_proj", 0),
             ("gate_up_proj", "up_proj", 1),
         ]
   -     params_dict = dict(self.named_parameters())
   +     params_dict = dict(vllm_model.named_parameters())
         loaded_params = set()
   -     for name, loaded_weight in weights:
   +     for name, loaded_weight in actor_weights.items():
             for (param_name, shard_name, shard_id) in stacked_params_mapping:
                 if shard_name not in name:
                     continue
                 name = name.replace(shard_name, param_name)
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
   +             local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight)
                 param = params_dict[name]
                 weight_loader = param.weight_loader
   -             weight_loader(param, loaded_weight, shard_id)
   +             weight_loader(param, local_loaded_weight.to(dtype=param.dtype), shard_id)
                 break
             else:
                 # lm_head is not used in vllm as it is tied with embed_token.
                 # To prevent errors, skip loading lm_head.weight.
                 if "lm_head.weight" in name:
                     continue
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
   +             local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight)
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
   -             weight_loader(param, loaded_weight)
   +             weight_loader(param, local_loaded_weight.to(dtype=param.dtype))
                 loaded_params.add(name)
         unloaded_params = params_dict.keys() - loaded_params
         if unloaded_params:
             raise RuntimeError(
                 "Some weights are not initialized from checkpoints: "
                 f"{unloaded_params}")
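
For step 5, the loader has to be discoverable by model architecture. Below is a
minimal registration sketch, assuming ``__MODEL_DTENSOR_WEIGHT_LOADER_REGISTRY__``
is a dict in ``dtensor_weight_loaders.py`` keyed by the vLLM model class name;
check that file for the exact registry structure and the lookup helper it
actually uses.

.. code-block:: python

   # Minimal registration sketch for step 5 (assumed registry layout;
   # see dtensor_weight_loaders.py for the real definition).
   __MODEL_DTENSOR_WEIGHT_LOADER_REGISTRY__ = {
       "GemmaForCausalLM": gemma_dtensor_weight_loader,
       # ... loaders for the other supported architectures ...
   }


   def load_dtensor_weights(actor_weights, vllm_model):
       # Hypothetical lookup helper: pick the loader by the vLLM model's
       # class name and apply it to copy the actor weights in.
       loader = __MODEL_DTENSOR_WEIGHT_LOADER_REGISTRY__[type(vllm_model).__name__]
       loader(actor_weights, vllm_model)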