Allow passing tp_plan in from_pretrained directly (#41435)

* start

* allow passing it

* fix plans

* fix

* fix

* style

* style

* fix

* add_test

* oops, fix indent

* fix

* fix

* fix for CI without accelerator

* fix import
Author: Cyril Vallez
Date: 2025-10-16 11:12:07 +02:00
Committed by: GitHub
Parent: 59efd86da2
Commit: 3ef6f2c415
3 changed files with 89 additions and 69 deletions


@@ -38,16 +38,15 @@ if is_torch_greater_or_equal("2.5") and _torch_distributed_available:
from torch.distributed.tensor import DTensor, Placement, Replicate, Shard
def initialize_tensor_parallelism(tp_plan, tp_size=None, device_mesh=None, device_map=None):
def initialize_tensor_parallelism(
tp_plan: str | dict[str, str] | None, tp_size: int | None = None, device_mesh=None, device_map=None
):
r"""
Sets up the device mesh and initializes the backend for tensor parallelism.
This function is called when the model is loaded and the TP plan is set to 'auto'.
"""
if tp_size is not None and tp_plan is None:
raise ValueError("tp_plan has to be set when tp_size is passed.")
if tp_plan is not None and tp_plan != "auto":
# TODO: we can relax this check when we support taking tp_plan from a json file, for example.
raise ValueError(f"tp_plan supports 'auto' only for now but got {tp_plan}.")
if tp_plan is not None and device_map is not None:
raise ValueError("`tp_plan` and `device_map` are mutually exclusive. Choose either one for parallelization.")
if device_mesh is None:
@@ -80,7 +79,7 @@ def initialize_tensor_parallelism(tp_plan, tp_size=None, device_mesh=None, devic
except Exception as e:
raise OSError(
"We tried to initialize torch.distributed for you, but it failed. Make "
"sure you init torch distributed in your script to use `tp_plan='auto'`."
"sure you init torch distributed in your script to use `tp_plan`."
) from e
if device_type != "cpu":
@@ -112,7 +111,7 @@ def initialize_tensor_parallelism(tp_plan, tp_size=None, device_mesh=None, devic
tp_size = device_mesh.size()
device_map = torch.device(f"{device_mesh.device_type}:{int(os.environ['LOCAL_RANK'])}")
return tp_device, device_map, device_mesh, tp_size
return device_map, device_mesh, tp_size
def _blocks_to_block_sizes(total_size: int, blocks: int | list[int]) -> list[int]:
@@ -1110,13 +1109,16 @@ def verify_tp_plan(expected_keys: list[str], tp_plan: dict[str, str] | None):
logger.warning(f"The following layers were not sharded: {', '.join(unsharded_layers)}")
def distribute_model(model, distributed_config, device_mesh, tp_size):
def distribute_model(model, tp_plan, distributed_config, device_mesh, tp_size):
model._tp_size = tp_size
model._device_mesh = device_mesh
if distributed_config is not None:
if isinstance(distributed_config, dict):
distributed_config = DistributedConfig.from_dict(distributed_config)
model.config.distributed_config = distributed_config
# Set the new requested tp_plan on the model
if isinstance(tp_plan, dict):
model.tp_plan = tp_plan
model_plan = model.tp_plan
if model_plan is not None and is_torch_greater_or_equal("2.5") and _torch_distributed_available:
for v in model_plan.values():

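In short, `distribute_model` now receives the requested `tp_plan` and, when it is a dict, installs it on the model before the sharding hooks are applied. A reduced sketch of that selection step, assuming a loaded `model` (the helper name below is illustrative, not a transformers API):

def resolve_tp_plan(model, tp_plan):
    # Sketch only: a user-supplied dict goes through the validated `tp_plan` setter
    # added in modeling_utils.py; otherwise the model keeps its predefined plan
    # (e.g. the one derived from `base_model_tp_plan`).
    if isinstance(tp_plan, dict):
        model.tp_plan = tp_plan
    return model.tp_plan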

@@ -68,6 +68,7 @@ from .integrations.peft import maybe_load_adapters
from .integrations.sdpa_attention import sdpa_attention_forward
from .integrations.sdpa_paged import sdpa_attention_paged_forward
from .integrations.tensor_parallel import (
ALL_PARALLEL_STYLES,
_get_parameter_tp_plan,
distribute_model,
initialize_tensor_parallelism,
@@ -1883,7 +1884,9 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
f" {self.__class__.__name__}"
)
self._tp_plan, self._ep_plan, self._pp_plan = {}, {}, {}
# If current model is a base model, attach `base_model_tp_plan`, `base_model_ep_plan` and `base_model_pp_plan` from config
if self.base_model is self:
self._pp_plan = self.config.base_model_pp_plan.copy() if self.config.base_model_pp_plan is not None else {}
self._tp_plan = self.config.base_model_tp_plan.copy() if self.config.base_model_tp_plan is not None else {}
self._ep_plan = self.config.base_model_ep_plan.copy() if self.config.base_model_ep_plan is not None else {}
@@ -1909,11 +1912,14 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
return self._pp_plan
@tp_plan.setter
def tp_plan(self, plan: dict[str, str]):
if plan is not None:
# Validate that all parallel styles in the plan are supported
from .integrations.tensor_parallel import ALL_PARALLEL_STYLES
def tp_plan(self, plan: dict[str, str] | None):
if plan is None:
self._tp_plan = {}
return
if not isinstance(plan, dict):
raise ValueError("Can only set a dictionary as `tp_plan`")
# Ensure the styles are all valid
for layer_pattern, parallel_style in plan.items():
if parallel_style not in ALL_PARALLEL_STYLES:
raise ValueError(
@@ -1921,11 +1927,9 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
f"Supported styles are {list(ALL_PARALLEL_STYLES.keys())}"
)
# Validate that the layer patterns match existing model structure
# We check this by getting all parameter names and seeing if any match the patterns
if hasattr(self, "named_parameters"):
# Validate that the layer patterns match existing model structure. We check this by getting all parameter
# names and seeing if any match the patterns
model_param_names = [name for name, _ in self.named_parameters()]
if model_param_names: # Only validate if model has parameters
for layer_pattern in plan.keys():
# Convert pattern to regex (replace * with .*)
regex_pattern = layer_pattern.replace("*", r"\d+")
@@ -1935,28 +1939,13 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
pattern_matched = True
break
if not pattern_matched:
# Try more flexible matching - check if pattern components exist
pattern_parts = layer_pattern.split(".")
flexible_matched = False
for param_name in model_param_names:
param_parts = param_name.split(".")
if len(pattern_parts) <= len(param_parts):
match_count = 0
for i, pattern_part in enumerate(pattern_parts):
if pattern_part == "*":
match_count += 1
elif i < len(param_parts) and pattern_part == param_parts[i]:
match_count += 1
if match_count == len(pattern_parts):
flexible_matched = True
break
if not flexible_matched:
warnings.warn(
f"Layer pattern '{layer_pattern}' does not match any parameters in the model. "
f"This rule may not be applied during tensor parallelization."
f"Layer pattern '{layer_pattern}' does not match any parameters in the model. This rule may not "
"be applied during tensor parallelization, or may lead to dimension mismatches"
)
self._tp_plan = plan if plan is not None else {}
# Set the plan
self._tp_plan = plan
@pp_plan.setter
def pp_plan(self, plan: dict[str, tuple[str, str]]):
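For reference, a `*` in a plan key stands for a layer index: the setter above rewrites it to `\d+` and warns when the resulting regex matches no parameter name. A minimal standalone sketch of that check (the helper name is illustrative, not part of the library):

import re
import warnings

def warn_on_unmatched_patterns(plan: dict[str, str], param_names: list[str]) -> None:
    # Sketch of the wildcard validation performed by the `tp_plan` setter
    for layer_pattern in plan:
        regex_pattern = layer_pattern.replace("*", r"\d+")  # '*' stands for a layer index
        if not any(re.search(regex_pattern, name) for name in param_names):
            warnings.warn(f"Layer pattern '{layer_pattern}' does not match any parameters in the model.")

# 'model.layers.*.self_attn.q_proj' matches 'model.layers.0.self_attn.q_proj.weight', so no warning here
warn_on_unmatched_patterns(
    {"model.layers.*.self_attn.q_proj": "colwise"},
    ["model.layers.0.self_attn.q_proj.weight"],
)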
@@ -4233,10 +4222,11 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
max_memory (`Dict`, *optional*):
A dictionary mapping device identifiers to their maximum memory when using `device_map`. Will default to the maximum memory
available for each GPU and the available CPU RAM if unset.
tp_plan (`str`, *optional*):
A torch tensor parallel plan, see [here](https://pytorch.org/tutorials/intermediate/TP_tutorial.html). Currently, it only accepts
`tp_plan="auto"` to use predefined plan based on the model. Note that if you use it, you should launch your script accordingly with
`torchrun [args] script.py`. This will be much faster than using a `device_map`, but has limitations.
tp_plan (`Optional[Union[dict, str]]`, *optional*):
A torch tensor parallel plan, see [here](https://pytorch.org/tutorials/intermediate/TP_tutorial.html). Use `tp_plan="auto"` to
use the predefined plan based on the model. If it's a dict, it should map module name patterns to the desired parallel style
(e.g. `"colwise"` or `"rowwise"`). Note that if you use it, you should launch your script accordingly with
`torchrun [args] script.py`. This will be much faster than using a `device_map`, but has limitations.
tp_size (`int`, *optional*):
The torch tensor parallel degree. If not provided, it defaults to the world size.
device_mesh (`torch.distributed.DeviceMesh`, *optional*):
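A minimal usage sketch of the documented dict form of `tp_plan`, mirroring the test added at the end of this commit (same model id and plan values); launch it with e.g. `torchrun --nproc_per_node 2 script.py`:

import torch
from transformers import AutoModelForCausalLM

# Shard only the attention projections; all other weights stay replicated
tp_plan = {
    "model.layers.*.self_attn.q_proj": "colwise",
    "model.layers.*.self_attn.k_proj": "colwise",
    "model.layers.*.self_attn.v_proj": "colwise",
    "model.layers.*.self_attn.o_proj": "rowwise",
}

model = AutoModelForCausalLM.from_pretrained("JackFram/llama-68m", dtype=torch.bfloat16, tp_plan=tp_plan)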
@@ -4333,7 +4323,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
):
key_mapping = cls._checkpoint_conversion_mapping
if distributed_config is not None:
if distributed_config is not None and tp_plan is None:
tp_plan = "auto"
# Not used anymore -- remove them from the kwargs
@@ -4371,7 +4361,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
)
if tp_plan is not None or tp_size is not None: # TP warnings, and setup
tp_plan, device_map, device_mesh, tp_size = initialize_tensor_parallelism(
device_map, device_mesh, tp_size = initialize_tensor_parallelism(
tp_plan, tp_size=tp_size, device_mesh=device_mesh, device_map=device_map
)
@@ -4491,7 +4481,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
)
if _torch_distributed_available and device_mesh is not None: # add hooks to nn.Modules: no weights
model = distribute_model(model, distributed_config, device_mesh, tp_size)
model = distribute_model(model, tp_plan, distributed_config, device_mesh, tp_size)
# Prepare the full device map
if device_map is not None:


@@ -75,9 +75,6 @@ class TestTensorParallel(TestCasePlus):
model_id = "JackFram/llama-68m"
rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])
model = AutoModelForCausalLM.from_pretrained(model_id, dtype="auto", tp_plan="auto")
torch.distributed.barrier()
@@ -141,9 +138,6 @@ class TestTensorParallel(TestCasePlus):
model_id = "JackFram/llama-68m"
rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])
model = AutoModelForCausalLM.from_pretrained(model_id, dtype="auto", tp_plan="auto")
torch.distributed.barrier()
@@ -155,7 +149,7 @@ class TestTensorParallel(TestCasePlus):
has_dtensor = 1
break
assert has_dtensor == 1, "TP model must has DTensor"
assert has_dtensor == 1, "TP model must have DTensor"
tokenizer = AutoTokenizer.from_pretrained(model_id)
prompt = "Can I help"
@@ -214,6 +208,40 @@ class TestTensorParallel(TestCasePlus):
assert torch.allclose(non_tp_tensor, tp_tensor), f"Tensor with key: {non_tp_key} does not match"
del non_tp_tensor, tp_tensor
def test_custom_tp_plan(self):
    script_to_run = textwrap.dedent(
        r"""
        import re
        import torch
        from torch.distributed.tensor import DTensor
        from transformers import AutoModelForCausalLM

        model_id = "JackFram/llama-68m"

        # only shard attentions, but not mlps
        tp_plan = {
            "model.layers.*.self_attn.q_proj": "colwise",
            "model.layers.*.self_attn.k_proj": "colwise",
            "model.layers.*.self_attn.v_proj": "colwise",
            "model.layers.*.self_attn.o_proj": "rowwise",
        }

        # Use custom tp_plan directly in from_pretrained
        model = AutoModelForCausalLM.from_pretrained(model_id, dtype=torch.bfloat16, tp_plan=tp_plan)

        # Check we can generate with the tp_plan
        inputs = torch.randint(100, 200, (1, 10), device=model.device)
        out = model.generate(inputs, max_new_tokens=10, do_sample=False)

        # Check only the attentions are sharded
        for name, param in model.named_parameters():
            if re.search(r"\.self_attn\.(q|k|v|o)_proj\.", name):
                assert isinstance(param, DTensor)
            else:
                assert not isinstance(param, DTensor)
        """
    )
    torchrun(script_to_run, self.nproc_per_node, env=self.get_env())
class TestTensorParallelProperties(TestCasePlus):
def test_tp_plan_property_setter_getter(self):