Allow passing tp_plan in from_pretrained directly (#41435)

* start
* allow passing it
* fix plans
* fix
* fix
* style
* style
* fix
* add_test
* oupsi indent
* fix
* fix
* fix for CI without accelerator
* fix import
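For orientation, a minimal usage sketch of what this PR enables, assembled from the test added below (the model id and the plan entries come from that test; the script name and the `--nproc-per-node` value are illustrative only):

# Run with: torchrun --nproc-per-node=2 tp_demo.py   (tp_demo.py is a placeholder name)
import torch
from transformers import AutoModelForCausalLM

model_id = "JackFram/llama-68m"
# Shard only the attention projections; MLP weights stay replicated
tp_plan = {
    "model.layers.*.self_attn.q_proj": "colwise",
    "model.layers.*.self_attn.k_proj": "colwise",
    "model.layers.*.self_attn.v_proj": "colwise",
    "model.layers.*.self_attn.o_proj": "rowwise",
}

# New in this PR: a dict tp_plan can be passed directly to from_pretrained
model = AutoModelForCausalLM.from_pretrained(model_id, dtype=torch.bfloat16, tp_plan=tp_plan)

inputs = torch.randint(100, 200, (1, 10), device=model.device)
out = model.generate(inputs, max_new_tokens=10, do_sample=False)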
@@ -38,16 +38,15 @@ if is_torch_greater_or_equal("2.5") and _torch_distributed_available:
     from torch.distributed.tensor import DTensor, Placement, Replicate, Shard


-def initialize_tensor_parallelism(tp_plan, tp_size=None, device_mesh=None, device_map=None):
+def initialize_tensor_parallelism(
+    tp_plan: str | dict[str, str] | None, tp_size: int | None = None, device_mesh=None, device_map=None
+):
     r"""
     Sets up the device mesh and initialized the backend for tensor parallelism.
     This function is called when the model is loaded and the TP plan is set to 'auto'.
     """
     if tp_size is not None and tp_plan is None:
         raise ValueError("tp_plan has to be set when tp_size is passed.")
-    if tp_plan is not None and tp_plan != "auto":
-        # TODO: we can relax this check when we support taking tp_plan from a json file, for example.
-        raise ValueError(f"tp_plan supports 'auto' only for now but got {tp_plan}.")
     if tp_plan is not None and device_map is not None:
         raise ValueError("`tp_plan` and `device_map` are mutually exclusive. Choose either one for parallelization.")
     if device_mesh is None:
@@ -80,7 +79,7 @@ def initialize_tensor_parallelism(tp_plan, tp_size=None, device_mesh=None, devic
         except Exception as e:
             raise OSError(
                 "We tried to initialize torch.distributed for you, but it failed. Make "
-                "sure you init torch distributed in your script to use `tp_plan='auto'`."
+                "sure you init torch distributed in your script to use `tp_plan`."
             ) from e

     if device_type != "cpu":
@@ -112,7 +111,7 @@ def initialize_tensor_parallelism(tp_plan, tp_size=None, device_mesh=None, devic
         tp_size = device_mesh.size()
         device_map = torch.device(f"{device_mesh.device_type}:{int(os.environ['LOCAL_RANK'])}")

-    return tp_device, device_map, device_mesh, tp_size
+    return device_map, device_mesh, tp_size


 def _blocks_to_block_sizes(total_size: int, blocks: int | list[int]) -> list[int]:
@@ -1110,13 +1109,16 @@ def verify_tp_plan(expected_keys: list[str], tp_plan: dict[str, str] | None):
         logger.warning(f"The following layers were not sharded: {', '.join(unsharded_layers)}")


-def distribute_model(model, distributed_config, device_mesh, tp_size):
+def distribute_model(model, tp_plan, distributed_config, device_mesh, tp_size):
     model._tp_size = tp_size
     model._device_mesh = device_mesh
     if distributed_config is not None:
         if isinstance(distributed_config, dict):
             distributed_config = DistributedConfig.from_dict(distributed_config)
         model.config.distributed_config = distributed_config
+    # Set the new requested tp_plan on the model
+    if isinstance(tp_plan, dict):
+        model.tp_plan = tp_plan
     model_plan = model.tp_plan
     if model_plan is not None and is_torch_greater_or_equal("2.5") and _torch_distributed_available:
         for v in model_plan.values():
@@ -68,6 +68,7 @@ from .integrations.peft import maybe_load_adapters
 from .integrations.sdpa_attention import sdpa_attention_forward
 from .integrations.sdpa_paged import sdpa_attention_paged_forward
 from .integrations.tensor_parallel import (
+    ALL_PARALLEL_STYLES,
     _get_parameter_tp_plan,
     distribute_model,
     initialize_tensor_parallelism,
@@ -1883,10 +1884,12 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
                 f" {self.__class__.__name__}"
             )

+        self._tp_plan, self._ep_plan, self._pp_plan = {}, {}, {}
         # If current model is a base model, attach `base_model_tp_plan` and `base_model_pp_plan` from config
-        self._pp_plan = self.config.base_model_pp_plan.copy() if self.config.base_model_pp_plan is not None else {}
-        self._tp_plan = self.config.base_model_tp_plan.copy() if self.config.base_model_tp_plan is not None else {}
-        self._ep_plan = self.config.base_model_ep_plan.copy() if self.config.base_model_ep_plan is not None else {}
+        if self.base_model is self:
+            self._pp_plan = self.config.base_model_pp_plan.copy() if self.config.base_model_pp_plan is not None else {}
+            self._tp_plan = self.config.base_model_tp_plan.copy() if self.config.base_model_tp_plan is not None else {}
+            self._ep_plan = self.config.base_model_ep_plan.copy() if self.config.base_model_ep_plan is not None else {}
         for name, module in self.named_children():
             if plan := getattr(module, "_ep_plan", None):
                 self._ep_plan.update({f"{name}.{k}": v for k, v in plan.copy().items()})
@@ -1909,54 +1912,40 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
         return self._pp_plan

     @tp_plan.setter
-    def tp_plan(self, plan: dict[str, str]):
-        if plan is not None:
-            # Validate that all parallel styles in the plan are supported
-            from .integrations.tensor_parallel import ALL_PARALLEL_STYLES
-
-            for layer_pattern, parallel_style in plan.items():
-                if parallel_style not in ALL_PARALLEL_STYLES:
-                    raise ValueError(
-                        f"Unsupported tensor parallel style '{parallel_style}' for layer '{layer_pattern}'. "
-                        f"Supported styles are {list(ALL_PARALLEL_STYLES.keys())}"
-                    )
-
-            # Validate that the layer patterns match existing model structure
-            # We check this by getting all parameter names and seeing if any match the patterns
-            if hasattr(self, "named_parameters"):
-                model_param_names = [name for name, _ in self.named_parameters()]
-                if model_param_names:  # Only validate if model has parameters
-                    for layer_pattern in plan.keys():
-                        # Convert pattern to regex (replace * with .*)
-                        regex_pattern = layer_pattern.replace("*", r"\d+")
-                        pattern_matched = False
-                        for param_name in model_param_names:
-                            if re.match(regex_pattern, param_name):
-                                pattern_matched = True
-                                break
-                        if not pattern_matched:
-                            # Try more flexible matching - check if pattern components exist
-                            pattern_parts = layer_pattern.split(".")
-                            flexible_matched = False
-                            for param_name in model_param_names:
-                                param_parts = param_name.split(".")
-                                if len(pattern_parts) <= len(param_parts):
-                                    match_count = 0
-                                    for i, pattern_part in enumerate(pattern_parts):
-                                        if pattern_part == "*":
-                                            match_count += 1
-                                        elif i < len(param_parts) and pattern_part == param_parts[i]:
-                                            match_count += 1
-                                    if match_count == len(pattern_parts):
-                                        flexible_matched = True
-                                        break
-                            if not flexible_matched:
-                                warnings.warn(
-                                    f"Layer pattern '{layer_pattern}' does not match any parameters in the model. "
-                                    f"This rule may not be applied during tensor parallelization."
-                                )
-
-        self._tp_plan = plan if plan is not None else {}
+    def tp_plan(self, plan: dict[str, str] | None):
+        if plan is None:
+            self._tp_plan = {}
+            return
+        if not isinstance(plan, dict):
+            raise ValueError("Can only set a dictionary as `tp_plan`")
+
+        # Ensure the styles are all valid
+        for layer_pattern, parallel_style in plan.items():
+            if parallel_style not in ALL_PARALLEL_STYLES:
+                raise ValueError(
+                    f"Unsupported tensor parallel style '{parallel_style}' for layer '{layer_pattern}'. "
+                    f"Supported styles are {list(ALL_PARALLEL_STYLES.keys())}"
+                )
+
+        # Validate that the layer patterns match existing model structure. We check this by getting all parameter
+        # names and seeing if any match the patterns
+        model_param_names = [name for name, _ in self.named_parameters()]
+        for layer_pattern in plan.keys():
+            # Convert pattern to regex (replace * with .*)
+            regex_pattern = layer_pattern.replace("*", r"\d+")
+            pattern_matched = False
+            for param_name in model_param_names:
+                if re.match(regex_pattern, param_name):
+                    pattern_matched = True
+                    break
+            if not pattern_matched:
+                warnings.warn(
+                    f"Layer pattern '{layer_pattern}' does not match any parameters in the model. This rule may not "
+                    "be applied during tensor parallelization, or may lead to dimension mismatches"
+                )
+
+        # Set the plan
+        self._tp_plan = plan

     @pp_plan.setter
     def pp_plan(self, plan: dict[str, tuple[str, str]]):
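For illustration, a minimal sketch (not part of the diff) of how the setter above behaves once a model is loaded; the layer patterns are hypothetical and the invalid style is chosen only to trigger the error path:

# Assumes `model` was loaded as in the usage sketch near the top of this page.
model.tp_plan = {"model.layers.*.self_attn.q_proj": "colwise"}  # accepted: "colwise" is in ALL_PARALLEL_STYLES

try:
    model.tp_plan = {"model.layers.*.self_attn.q_proj": "not_a_style"}  # hypothetical invalid style
except ValueError as e:
    print(e)  # Unsupported tensor parallel style ...

model.tp_plan = {"model.blocks.*.foo": "colwise"}  # hypothetical pattern: kept, but a warning is emitted

model.tp_plan = None  # clears the plan (it becomes {})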
@@ -4233,10 +4222,11 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
             max_memory (`Dict`, *optional*):
                 A dictionary device identifier to maximum memory if using `device_map`. Will default to the maximum memory available for each
                 GPU and the available CPU RAM if unset.
-            tp_plan (`str`, *optional*):
-                A torch tensor parallel plan, see [here](https://pytorch.org/tutorials/intermediate/TP_tutorial.html). Currently, it only accepts
-                `tp_plan="auto"` to use predefined plan based on the model. Note that if you use it, you should launch your script accordingly with
-                `torchrun [args] script.py`. This will be much faster than using a `device_map`, but has limitations.
+            tp_plan (`Optional[Union[dict, str]]`, *optional*):
+                A torch tensor parallel plan, see [here](https://pytorch.org/tutorials/intermediate/TP_tutorial.html). Use `tp_plan="auto"` to
+                use the predefined plan based on the model. If it's a dict, then it should match between module names and desired layout.
+                Note that if you use it, you should launch your script accordingly with `torchrun [args] script.py`. This will be much
+                faster than using a `device_map`, but has limitations.
             tp_size (`str`, *optional*):
                 A torch tensor parallel degree. If not provided would default to world size.
             device_mesh (`torch.distributed.DeviceMesh`, *optional*):
@@ -4333,7 +4323,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
         ):
             key_mapping = cls._checkpoint_conversion_mapping

-        if distributed_config is not None:
+        if distributed_config is not None and tp_plan is None:
             tp_plan = "auto"

         # Not used anymore -- remove them from the kwargs
@@ -4371,7 +4361,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
             )

         if tp_plan is not None or tp_size is not None:  # TP warnings, and setup
-            tp_plan, device_map, device_mesh, tp_size = initialize_tensor_parallelism(
+            device_map, device_mesh, tp_size = initialize_tensor_parallelism(
                 tp_plan, tp_size=tp_size, device_mesh=device_mesh, device_map=device_map
             )
@@ -4491,7 +4481,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
             )

         if _torch_distributed_available and device_mesh is not None:  # add hooks to nn.Modules: no weights
-            model = distribute_model(model, distributed_config, device_mesh, tp_size)
+            model = distribute_model(model, tp_plan, distributed_config, device_mesh, tp_size)

         # Prepare the full device map
         if device_map is not None:
@@ -75,9 +75,6 @@ class TestTensorParallel(TestCasePlus):

        model_id = "JackFram/llama-68m"

-        rank = int(os.environ["RANK"])
-        world_size = int(os.environ["WORLD_SIZE"])
-
        model = AutoModelForCausalLM.from_pretrained(model_id, dtype="auto", tp_plan="auto")
        torch.distributed.barrier()

@@ -141,9 +138,6 @@ class TestTensorParallel(TestCasePlus):

        model_id = "JackFram/llama-68m"

-        rank = int(os.environ["RANK"])
-        world_size = int(os.environ["WORLD_SIZE"])
-
        model = AutoModelForCausalLM.from_pretrained(model_id, dtype="auto", tp_plan="auto")
        torch.distributed.barrier()

@@ -155,7 +149,7 @@ class TestTensorParallel(TestCasePlus):
                has_dtensor = 1
                break

-        assert has_dtensor == 1, "TP model must has DTensor"
+        assert has_dtensor == 1, "TP model must have DTensor"

        tokenizer = AutoTokenizer.from_pretrained(model_id)
        prompt = "Can I help"
@@ -214,6 +208,40 @@ class TestTensorParallel(TestCasePlus):
                    assert torch.allclose(non_tp_tensor, tp_tensor), f"Tensor with key: {non_tp_key} does not match"
                    del non_tp_tensor, tp_tensor

+    def test_custom_tp_plan(self):
+        script_to_run = textwrap.dedent(
+            r"""
+            import re
+            import torch
+            from torch.distributed.tensor import DTensor
+            from transformers import AutoModelForCausalLM
+
+            model_id = "JackFram/llama-68m"
+            # only shard attentions, but not mlps
+            tp_plan = {
+                "model.layers.*.self_attn.q_proj": "colwise",
+                "model.layers.*.self_attn.k_proj": "colwise",
+                "model.layers.*.self_attn.v_proj": "colwise",
+                "model.layers.*.self_attn.o_proj": "rowwise",
+            }
+
+            # Use custom tp_plan directly in from_pretrained
+            model = AutoModelForCausalLM.from_pretrained(model_id, dtype=torch.bfloat16, tp_plan=tp_plan)
+
+            # Check we can generate with the tp_plan
+            inputs = torch.randint(100, 200, (1, 10), device=model.device)
+            out = model.generate(inputs, max_new_tokens=10, do_sample=False)
+
+            # Check only the attentions are sharded
+            for name, param in model.named_parameters():
+                if re.search(r"\.self_attn\.(q|k|v|o)_proj\.", name):
+                    assert isinstance(param, DTensor)
+                else:
+                    assert not isinstance(param, DTensor)
+            """
+        )
+        torchrun(script_to_run, self.nproc_per_node, env=self.get_env())
+
+
 class TestTensorParallelProperties(TestCasePlus):
     def test_tp_plan_property_setter_getter(self):