[docs] Manual tp-plan (#41674)

* manual tp-plan

* feedback
Steven Liu
2025-10-17 09:38:26 -07:00
committed by GitHub
parent 347a0f9e83
commit e7592f2508

## Partitioning a model
Transformers supports tensor parallelism if a model has a `tp_plan`. There are two ways to partition a model:
- Set `tp_plan="auto"` to automatically use a tensor parallelism plan based on a model's predefined configuration.
- Define and pass a manual `tp_plan`.
<hfoptions id="tp_plan">
<hfoption id="auto plan">
```py
import os
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct", dtype=torch.bfloat16, tp_plan="auto")
print(model._tp_plan)
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
prompt = "Can I help"
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

# distributed run
outputs = model(**inputs)
```

Launch the inference script above on [torchrun](https://pytorch.org/docs/stable/elastic/run.html) with 4 processes, one per GPU.

```bash
torchrun --nproc-per-node 4 demo.py
```
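To verify that the weights were actually sharded, a quick sanity check can be appended to the script. This is a sketch, not part of the official docs: it assumes a Llama-style module path (`model.model.layers[0].self_attn.q_proj`) and relies on the `RANK` environment variable that torchrun sets for each process.

```py
# A sketch: under torchrun, every process builds the same module tree but only
# holds its own shard of the weights. Sharded parameters are DTensor-like
# objects in recent versions, so printing the parameter type per rank is a
# cheap way to confirm the model was partitioned.
import os

rank = int(os.environ["RANK"])
weight = model.model.layers[0].self_attn.q_proj.weight
print(f"rank {rank}: q_proj weight type = {type(weight).__name__}")
```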
</hfoption>
<hfoption id="manual plan">
Define a tensor parallel plan for each layer in `tp_plan` and pass it to [`~PreTrainedModel.from_pretrained`]. The example below uses column and row partitioning. See the [Partitioning strategies](#partitioning-strategies) section for other supported strategies.
Manual partitioning requires a deep understanding of the model architecture and of how the partitioning strategies interact. Poor partitioning choices can slow a model down, cause it to fail, or produce incorrect results. For example, `colwise` splits a linear layer's output features across ranks while `rowwise` splits its input features, which is why the attention projections pair `colwise` q/k/v projections with a `rowwise` output projection. The [Ultra-Scale Playbook](https://huggingface.co/spaces/nanotron/ultrascale-playbook?section=tensor_parallelism) explains partitioning strategies in detail.
```py
from transformers import AutoModelForCausalLM
tp_plan = {
    "model.layers.*.self_attn.q_proj": "colwise",
    "model.layers.*.self_attn.k_proj": "colwise",
    "model.layers.*.self_attn.v_proj": "colwise",
    "model.layers.*.self_attn.o_proj": "rowwise",
    # ... entries for the remaining layers go here (for example, the MLP projections)
}
model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct", dtype="auto", tp_plan=tp_plan)
print(model.tp_plan)
```
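To give a sense of what the elided entries might look like, the sketch below extends the plan to the MLP projections of a Llama-style model. The module names (`gate_proj`, `up_proj`, `down_proj`) are assumptions; verify them against `model.named_modules()` for your architecture, and update the plan before calling [`~PreTrainedModel.from_pretrained`] so the weights are sharded at load time.

```py
# A sketch of additional entries for Llama-style MLP blocks; the module names
# are assumptions to check with model.named_modules() before relying on this.
tp_plan.update({
    "model.layers.*.mlp.gate_proj": "colwise",  # split output features across ranks
    "model.layers.*.mlp.up_proj": "colwise",    # split output features across ranks
    "model.layers.*.mlp.down_proj": "rowwise",  # split input features across ranks
})
```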
</hfoption>
</hfoptions>
## Partitioning strategies
All partitioning strategies are defined in the [`ParallelInterface`] class, which maps a string to the strategy implementation. You don't need to interact with this class directly since all strategies are set with `tp_plan` in [`~PreTrainedModel.from_pretrained`], but it is useful for checking which strategies are available.
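As a quick way to list the available strategy names, something like the sketch below should work. It assumes [`ParallelInterface`] lives in `transformers.integrations.tensor_parallel` and keeps its registry in a dict-like `_global_mapping` attribute; both the import path and the attribute name may differ across versions.

```py
# A sketch, not verified against a specific release: ParallelInterface is
# assumed to be a dict-like registry mapping strategy names ("colwise",
# "rowwise", ...) to their implementations.
from transformers.integrations.tensor_parallel import ParallelInterface

print(list(ParallelInterface._global_mapping.keys()))
```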