Allow passing tp_plan in from_pretrained directly (#41435)
Squashed commit history:
* start
* allow passing it
* fix plans
* fix
* fix
* style
* style
* fix
* add_test
* oupsi indent
* fix
* fix
* fix for CI without accelerator
* fix import
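With this change, `from_pretrained` accepts a custom tensor parallel plan passed as a dict, in addition to `tp_plan="auto"`. A minimal usage sketch, adapted from the test added in this commit (the model id and plan come from that test; the script has to be launched with torchrun, e.g. `torchrun --nproc_per_node=2 script.py`):

import torch
from transformers import AutoModelForCausalLM

model_id = "JackFram/llama-68m"
# Shard only the attention projections; the MLP weights stay replicated
tp_plan = {
    "model.layers.*.self_attn.q_proj": "colwise",
    "model.layers.*.self_attn.k_proj": "colwise",
    "model.layers.*.self_attn.v_proj": "colwise",
    "model.layers.*.self_attn.o_proj": "rowwise",
}

model = AutoModelForCausalLM.from_pretrained(model_id, dtype=torch.bfloat16, tp_plan=tp_plan)
inputs = torch.randint(100, 200, (1, 10), device=model.device)
out = model.generate(inputs, max_new_tokens=10, do_sample=False)

As before, `tp_plan` and `device_map` are mutually exclusive, and `tp_plan="auto"` keeps using the plan predefined for the model.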
@@ -38,16 +38,15 @@ if is_torch_greater_or_equal("2.5") and _torch_distributed_available:
     from torch.distributed.tensor import DTensor, Placement, Replicate, Shard


-def initialize_tensor_parallelism(tp_plan, tp_size=None, device_mesh=None, device_map=None):
+def initialize_tensor_parallelism(
+    tp_plan: str | dict[str, str] | None, tp_size: int | None = None, device_mesh=None, device_map=None
+):
     r"""
     Sets up the device mesh and initialized the backend for tensor parallelism.
     This function is called when the model is loaded and the TP plan is set to 'auto'.
     """
     if tp_size is not None and tp_plan is None:
         raise ValueError("tp_plan has to be set when tp_size is passed.")
-    if tp_plan is not None and tp_plan != "auto":
-        # TODO: we can relax this check when we support taking tp_plan from a json file, for example.
-        raise ValueError(f"tp_plan supports 'auto' only for now but got {tp_plan}.")
     if tp_plan is not None and device_map is not None:
         raise ValueError("`tp_plan` and `device_map` are mutually exclusive. Choose either one for parallelization.")
     if device_mesh is None:
@@ -80,7 +79,7 @@ def initialize_tensor_parallelism(tp_plan, tp_size=None, device_mesh=None, devic
         except Exception as e:
             raise OSError(
                 "We tried to initialize torch.distributed for you, but it failed. Make "
-                "sure you init torch distributed in your script to use `tp_plan='auto'`."
+                "sure you init torch distributed in your script to use `tp_plan`."
             ) from e

     if device_type != "cpu":
@@ -112,7 +111,7 @@ def initialize_tensor_parallelism(tp_plan, tp_size=None, device_mesh=None, devic
         tp_size = device_mesh.size()
         device_map = torch.device(f"{device_mesh.device_type}:{int(os.environ['LOCAL_RANK'])}")

-    return tp_device, device_map, device_mesh, tp_size
+    return device_map, device_mesh, tp_size


 def _blocks_to_block_sizes(total_size: int, blocks: int | list[int]) -> list[int]:
@@ -1110,13 +1109,16 @@ def verify_tp_plan(expected_keys: list[str], tp_plan: dict[str, str] | None):
         logger.warning(f"The following layers were not sharded: {', '.join(unsharded_layers)}")


-def distribute_model(model, distributed_config, device_mesh, tp_size):
+def distribute_model(model, tp_plan, distributed_config, device_mesh, tp_size):
     model._tp_size = tp_size
     model._device_mesh = device_mesh
     if distributed_config is not None:
         if isinstance(distributed_config, dict):
             distributed_config = DistributedConfig.from_dict(distributed_config)
         model.config.distributed_config = distributed_config
+    # Set the new requested tp_plan on the model
+    if isinstance(tp_plan, dict):
+        model.tp_plan = tp_plan
     model_plan = model.tp_plan
     if model_plan is not None and is_torch_greater_or_equal("2.5") and _torch_distributed_available:
         for v in model_plan.values():
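The values of a dict plan must be styles registered in ALL_PARALLEL_STYLES; the tp_plan setter changed below rejects anything else. To see the accepted names, a small sketch (assuming only that the integration module shown above is importable as transformers.integrations.tensor_parallel):

from transformers.integrations.tensor_parallel import ALL_PARALLEL_STYLES

# Prints the accepted style names, e.g. "colwise" and "rowwise" as used in the tests below
print(list(ALL_PARALLEL_STYLES.keys()))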
@@ -68,6 +68,7 @@ from .integrations.peft import maybe_load_adapters
 from .integrations.sdpa_attention import sdpa_attention_forward
 from .integrations.sdpa_paged import sdpa_attention_paged_forward
 from .integrations.tensor_parallel import (
+    ALL_PARALLEL_STYLES,
     _get_parameter_tp_plan,
     distribute_model,
     initialize_tensor_parallelism,
@@ -1883,10 +1884,12 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
                 f" {self.__class__.__name__}"
             )

+        self._tp_plan, self._ep_plan, self._pp_plan = {}, {}, {}
         # If current model is a base model, attach `base_model_tp_plan` and `base_model_pp_plan` from config
-        self._pp_plan = self.config.base_model_pp_plan.copy() if self.config.base_model_pp_plan is not None else {}
-        self._tp_plan = self.config.base_model_tp_plan.copy() if self.config.base_model_tp_plan is not None else {}
-        self._ep_plan = self.config.base_model_ep_plan.copy() if self.config.base_model_ep_plan is not None else {}
+        if self.base_model is self:
+            self._pp_plan = self.config.base_model_pp_plan.copy() if self.config.base_model_pp_plan is not None else {}
+            self._tp_plan = self.config.base_model_tp_plan.copy() if self.config.base_model_tp_plan is not None else {}
+            self._ep_plan = self.config.base_model_ep_plan.copy() if self.config.base_model_ep_plan is not None else {}
         for name, module in self.named_children():
             if plan := getattr(module, "_ep_plan", None):
                 self._ep_plan.update({f"{name}.{k}": v for k, v in plan.copy().items()})
@@ -1909,54 +1912,40 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
         return self._pp_plan

     @tp_plan.setter
-    def tp_plan(self, plan: dict[str, str]):
-        if plan is not None:
-            # Validate that all parallel styles in the plan are supported
-            from .integrations.tensor_parallel import ALL_PARALLEL_STYLES
+    def tp_plan(self, plan: dict[str, str] | None):
+        if plan is None:
+            self._tp_plan = {}
+            return
+        if not isinstance(plan, dict):
+            raise ValueError("Can only set a dictionary as `tp_plan`")

-            for layer_pattern, parallel_style in plan.items():
-                if parallel_style not in ALL_PARALLEL_STYLES:
-                    raise ValueError(
-                        f"Unsupported tensor parallel style '{parallel_style}' for layer '{layer_pattern}'. "
-                        f"Supported styles are {list(ALL_PARALLEL_STYLES.keys())}"
-                    )
+        # Ensure the styles are all valid
+        for layer_pattern, parallel_style in plan.items():
+            if parallel_style not in ALL_PARALLEL_STYLES:
+                raise ValueError(
+                    f"Unsupported tensor parallel style '{parallel_style}' for layer '{layer_pattern}'. "
+                    f"Supported styles are {list(ALL_PARALLEL_STYLES.keys())}"
+                )

-            # Validate that the layer patterns match existing model structure
-            # We check this by getting all parameter names and seeing if any match the patterns
-            if hasattr(self, "named_parameters"):
-                model_param_names = [name for name, _ in self.named_parameters()]
-                if model_param_names:  # Only validate if model has parameters
-                    for layer_pattern in plan.keys():
-                        # Convert pattern to regex (replace * with .*)
-                        regex_pattern = layer_pattern.replace("*", r"\d+")
-                        pattern_matched = False
-                        for param_name in model_param_names:
-                            if re.match(regex_pattern, param_name):
-                                pattern_matched = True
-                                break
-                        if not pattern_matched:
-                            # Try more flexible matching - check if pattern components exist
-                            pattern_parts = layer_pattern.split(".")
-                            flexible_matched = False
-                            for param_name in model_param_names:
-                                param_parts = param_name.split(".")
-                                if len(pattern_parts) <= len(param_parts):
-                                    match_count = 0
-                                    for i, pattern_part in enumerate(pattern_parts):
-                                        if pattern_part == "*":
-                                            match_count += 1
-                                        elif i < len(param_parts) and pattern_part == param_parts[i]:
-                                            match_count += 1
-                                    if match_count == len(pattern_parts):
-                                        flexible_matched = True
-                                        break
-                            if not flexible_matched:
-                                warnings.warn(
-                                    f"Layer pattern '{layer_pattern}' does not match any parameters in the model. "
-                                    f"This rule may not be applied during tensor parallelization."
-                                )
+        # Validate that the layer patterns match existing model structure. We check this by getting all parameter
+        # names and seeing if any match the patterns
+        model_param_names = [name for name, _ in self.named_parameters()]
+        for layer_pattern in plan.keys():
+            # Convert pattern to regex (replace * with .*)
+            regex_pattern = layer_pattern.replace("*", r"\d+")
+            pattern_matched = False
+            for param_name in model_param_names:
+                if re.match(regex_pattern, param_name):
+                    pattern_matched = True
+                    break
+            if not pattern_matched:
+                warnings.warn(
+                    f"Layer pattern '{layer_pattern}' does not match any parameters in the model. This rule may not "
+                    "be applied during tensor parallelization, or may lead to dimension mismatches"
+                )

-        self._tp_plan = plan if plan is not None else {}
+        # Set the plan
+        self._tp_plan = plan

     @pp_plan.setter
     def pp_plan(self, plan: dict[str, tuple[str, str]]):
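For illustration, the wildcard check the new setter performs can be reproduced standalone: each `*` in a plan key is replaced by `\d+` (a layer index) and `re.match` only has to match a prefix of a parameter name. A self-contained sketch with made-up parameter names (not taken from a real checkpoint):

import re

plan = {
    "model.layers.*.self_attn.q_proj": "colwise",
    "model.layers.*.mlp.gate_proj": "colwise",  # hypothetical entry that matches nothing below
}
param_names = [
    "model.layers.0.self_attn.q_proj.weight",
    "model.layers.1.self_attn.q_proj.weight",
]

for layer_pattern in plan:
    regex_pattern = layer_pattern.replace("*", r"\d+")
    matched = any(re.match(regex_pattern, name) for name in param_names)
    print(f"{layer_pattern}: {'matches' if matched else 'no match, the setter would warn'}")

Unmatched patterns only trigger a warning; an unknown parallel style, by contrast, raises a ValueError.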
@@ -4233,10 +4222,11 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
             max_memory (`Dict`, *optional*):
                 A dictionary device identifier to maximum memory if using `device_map`. Will default to the maximum memory available for each
                 GPU and the available CPU RAM if unset.
-            tp_plan (`str`, *optional*):
-                A torch tensor parallel plan, see [here](https://pytorch.org/tutorials/intermediate/TP_tutorial.html). Currently, it only accepts
-                `tp_plan="auto"` to use predefined plan based on the model. Note that if you use it, you should launch your script accordingly with
-                `torchrun [args] script.py`. This will be much faster than using a `device_map`, but has limitations.
+            tp_plan (`Optional[Union[dict, str]]`, *optional*):
+                A torch tensor parallel plan, see [here](https://pytorch.org/tutorials/intermediate/TP_tutorial.html). Use `tp_plan="auto"` to
+                use the predefined plan based on the model. If it's a dict, then it should match between module names and desired layout.
+                Note that if you use it, you should launch your script accordingly with `torchrun [args] script.py`. This will be much
+                faster than using a `device_map`, but has limitations.
             tp_size (`str`, *optional*):
                 A torch tensor parallel degree. If not provided would default to world size.
             device_mesh (`torch.distributed.DeviceMesh`, *optional*):
@@ -4333,7 +4323,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
         ):
             key_mapping = cls._checkpoint_conversion_mapping

-        if distributed_config is not None:
+        if distributed_config is not None and tp_plan is None:
             tp_plan = "auto"

         # Not used anymore -- remove them from the kwargs
@@ -4371,7 +4361,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
             )

         if tp_plan is not None or tp_size is not None:  # TP warnings, and setup
-            tp_plan, device_map, device_mesh, tp_size = initialize_tensor_parallelism(
+            device_map, device_mesh, tp_size = initialize_tensor_parallelism(
                 tp_plan, tp_size=tp_size, device_mesh=device_mesh, device_map=device_map
             )

@@ -4491,7 +4481,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
             )

         if _torch_distributed_available and device_mesh is not None:  # add hooks to nn.Modules: no weights
-            model = distribute_model(model, distributed_config, device_mesh, tp_size)
+            model = distribute_model(model, tp_plan, distributed_config, device_mesh, tp_size)

         # Prepare the full device map
         if device_map is not None:
@@ -75,9 +75,6 @@ class TestTensorParallel(TestCasePlus):

             model_id = "JackFram/llama-68m"

-            rank = int(os.environ["RANK"])
-            world_size = int(os.environ["WORLD_SIZE"])
-
             model = AutoModelForCausalLM.from_pretrained(model_id, dtype="auto", tp_plan="auto")
             torch.distributed.barrier()

@@ -141,9 +138,6 @@ class TestTensorParallel(TestCasePlus):

             model_id = "JackFram/llama-68m"

-            rank = int(os.environ["RANK"])
-            world_size = int(os.environ["WORLD_SIZE"])
-
             model = AutoModelForCausalLM.from_pretrained(model_id, dtype="auto", tp_plan="auto")
             torch.distributed.barrier()

@@ -155,7 +149,7 @@ class TestTensorParallel(TestCasePlus):
                     has_dtensor = 1
                     break

-            assert has_dtensor == 1, "TP model must has DTensor"
+            assert has_dtensor == 1, "TP model must have DTensor"

             tokenizer = AutoTokenizer.from_pretrained(model_id)
             prompt = "Can I help"
@@ -214,6 +208,40 @@ class TestTensorParallel(TestCasePlus):
                 assert torch.allclose(non_tp_tensor, tp_tensor), f"Tensor with key: {non_tp_key} does not match"
                 del non_tp_tensor, tp_tensor

+    def test_custom_tp_plan(self):
+        script_to_run = textwrap.dedent(
+            r"""
+            import re
+            import torch
+            from torch.distributed.tensor import DTensor
+            from transformers import AutoModelForCausalLM
+
+            model_id = "JackFram/llama-68m"
+            # only shard attentions, but not mlps
+            tp_plan = {
+                "model.layers.*.self_attn.q_proj": "colwise",
+                "model.layers.*.self_attn.k_proj": "colwise",
+                "model.layers.*.self_attn.v_proj": "colwise",
+                "model.layers.*.self_attn.o_proj": "rowwise",
+            }
+
+            # Use custom tp_plan directly in from_pretrained
+            model = AutoModelForCausalLM.from_pretrained(model_id, dtype=torch.bfloat16, tp_plan=tp_plan)
+
+            # Check we can generate with the tp_plan
+            inputs = torch.randint(100, 200, (1, 10), device=model.device)
+            out = model.generate(inputs, max_new_tokens=10, do_sample=False)
+
+            # Check only the attentions are sharded
+            for name, param in model.named_parameters():
+                if re.search(r"\.self_attn\.(q|k|v|o)_proj\.", name):
+                    assert isinstance(param, DTensor)
+                else:
+                    assert not isinstance(param, DTensor)
+            """
+        )
+        torchrun(script_to_run, self.nproc_per_node, env=self.get_env())
+

 class TestTensorParallelProperties(TestCasePlus):
     def test_tp_plan_property_setter_getter(self):