Compare commits

...

7 Commits

SHA1 Message Date
d4b01f2fd3 Fix: correct sharding dims 2025-04-10 21:51:57 +00:00
f13e867c7b WIP: boiler plate for dp+tp 2025-04-10 16:04:41 +00:00
c07373687d Merge branch 'main' into tp-size 2025-04-10 01:46:49 +02:00
d77505c06e Merge branch 'main' into tp-size 2025-04-09 17:50:12 +02:00
a65130c6fa fix: nit in docs (Signed-off-by: Mehant Kammakomati <mehant.kammakomati2@ibm.com>) 2025-04-09 18:56:38 +05:30
bb2950d149 fix: review cmt - error when tp_plan not set for tp_size (Signed-off-by: Mehant Kammakomati <mehant.kammakomati2@ibm.com>) 2025-04-09 18:56:38 +05:30
1059fffb45 feat: custom tp_size, new transformers tp interface (Signed-off-by: Mehant Kammakomati <mehant.kammakomati2@ibm.com>) 2025-04-09 18:56:38 +05:30
7 changed files with 79 additions and 126 deletions

View File

@@ -674,29 +674,7 @@ use_cpu: false
```
</hfoption>
<hfoption id="Tensor Parallelism with PyTorch 2">
```yml
compute_environment: LOCAL_MACHINE
tp_config:
tp_size: 4
distributed_type: TP
downcast_bf16: 'no'
machine_rank: 0
main_training_function: main
mixed_precision: 'no'
num_machines: 1
num_processes: 4
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
```
</hfoption>
</hfoptions>
The [`accelerate_launch`](https://huggingface.co/docs/accelerate/package_reference/cli#accelerate-launch) command is the recommended way to run your training script on a distributed system with Accelerate and [`Trainer`], with the parameters specified in `config_file.yaml`. This file is saved to the Accelerate cache folder and is automatically loaded when you run `accelerate_launch`.

View File

@@ -341,29 +341,9 @@ use_cpu: false
```
</hfoption>
<hfoption id="Tensor parallelism with PyTorch 2">
```yaml
compute_environment: LOCAL_MACHINE
tp_config:
tp_size: 4
distributed_type: TP
downcast_bf16: 'no'
machine_rank: 0
main_training_function: main
mixed_precision: 'no'
num_machines: 1
num_processes: 4
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
```
</hfoptions>
Run [accelerate_launch](https://hf.co/docs/accelerate/package_reference/cli#accelerate-launch) to start training with the configurations set in `config_file.yaml`. This file is saved to the Accelerate cache folder and automatically loaded when you run `accelerate_launch`.
The example below launches the [run_glue.py](../../../examples/pytorch/text-classification/run_glue) script with the FSDP configuration shown earlier. Parameters from the `config_file.yaml` file can also be directly set in the command line.

View File

@@ -363,29 +363,6 @@ use_cpu: false
</hfoption>
<hfoption id="Tensor Parallelism with PyTorch 2">
```yml
compute_environment: LOCAL_MACHINE
tp_config:
tp_size: 4
distributed_type: TP
downcast_bf16: 'no'
machine_rank: 0
main_training_function: main
mixed_precision: 'no'
num_machines: 1
num_processes: 4
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
```
</hfoption>
</hfoptions>
The [`accelerate_launch`](https://huggingface.co/docs/accelerate/package_reference/cli#accelerate-launch) command is the recommended way to launch your training script on a distributed system with Accelerate and [`Trainer`], with the parameters specified in `config_file.yaml`. This file is saved to the Accelerate cache folder and is automatically loaded when you run `accelerate_launch`.

View File

@@ -549,29 +549,7 @@ use_cpu: false
```
</hfoption>
<hfoption id="Tensor Parallelism with PyTorch 2">
```yml
compute_environment: LOCAL_MACHINE
tp_config:
tp_size: 4
distributed_type: TP
downcast_bf16: 'no'
machine_rank: 0
main_training_function: main
mixed_precision: 'no'
num_machines: 1
num_processes: 4
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
```
</hfoption>
</hfoptions>
The [`accelerate_launch`](https://huggingface.co/docs/accelerate/package_reference/cli#accelerate-launch) command is the recommended way to run a training script on a distributed system with Accelerate and [`Trainer`], using the parameters specified in `config_file.yaml`. This file is saved to the Accelerate cache folder and is automatically loaded when you run `accelerate_launch`.

View File

@@ -734,6 +734,7 @@ def _load_state_dict_into_meta_model(
)
if device_mesh is not None: # In this case, the param is already on the correct device!
rank = device_mesh.get_local_rank("tp")
shard_and_distribute_module(
model,
param,
@@ -741,8 +742,8 @@ def _load_state_dict_into_meta_model(
param_name,
casting_dtype,
to_contiguous,
int(os.environ["RANK"]), # the rank
device_mesh,
rank, # the rank
device_mesh["tp"],
)
else:
param = param[...]
@@ -1784,6 +1785,13 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
# for example.
_tp_plan = None
# tensor parallel degree to which the model is sharded.
_tp_size = None
# data parallel degree to be used, if any; it is forwarded to
# accelerate so the correct device mesh can be built
_dp_size = None
# A pipeline parallel plan specifying the layers which may not be present
# on all ranks when PP is enabled. For top-level models, this attribute is
# currently defined in respective model code. For base models, this
@@ -3845,6 +3853,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
A torch tensor parallel plan, see [here](https://pytorch.org/tutorials/intermediate/TP_tutorial.html). Currently, it only accepts
`tp_plan="auto"` to use predefined plan based on the model. Note that if you use it, you should launch your script accordingly with
`torchrun [args] script.py`. This will be much faster than using a `device_map`, but has limitations.
tp_size (`int`, *optional*):
A torch tensor parallel degree. If not provided, it defaults to the world size.
dp_size (`int`, *optional*):
A torch data parallel degree. It is only used to create the correct device mesh when `tp_size` is provided,
and is only consumed by Accelerate to compose TP with FSDP/DDP.
offload_folder (`str` or `os.PathLike`, *optional*):
If the `device_map` contains any value `"disk"`, the folder where we will offload weights.
offload_state_dict (`bool`, *optional*):
@@ -3941,6 +3954,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
generation_config = kwargs.pop("generation_config", None)
gguf_file = kwargs.pop("gguf_file", None)
tp_plan = kwargs.pop("tp_plan", None)
tp_size = kwargs.pop("tp_size", None)
dp_size = kwargs.pop("dp_size", None)
key_mapping = kwargs.pop("key_mapping", None)
# Not used anymore -- remove them from the kwargs
_ = kwargs.pop("resume_download", None)
@@ -3953,7 +3968,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
raise ValueError(
"`state_dict` cannot be passed together with a model name or a `gguf_file`. Use one of the two loading strategies."
)
if tp_size is not None and tp_plan is None:
raise ValueError("tp_plan has to be set when tp_size is passed.")
if tp_plan is not None and tp_plan != "auto":
# TODO: we can relax this check when we support taking tp_plan from a json file, for example.
raise ValueError(f"tp_plan supports 'auto' only for now but got {tp_plan}.")
@@ -3961,6 +3977,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
raise ValueError(
"`tp_plan` and `device_map` are mutually exclusive. Choose either one for parallelization."
)
# dp_size is only meaningful together with tp_size; ignore it otherwise
if tp_size is None and dp_size is not None:
dp_size = None
# If torchrun was used, make sure to TP by default. This way people don't need to change tp or device map
if device_map == "auto" and tp_plan is None and int(os.environ.get("WORLD_SIZE", 0)):
@@ -4007,9 +4025,32 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
sys.stderr = open(os.devnull, "w")
# This is the easiest way to dispatch to the current process device
device_map = tp_device
# Assuming sharding the model onto the world
world_size = torch.distributed.get_world_size()
device_mesh = torch.distributed.init_device_mesh(tp_device.type, (world_size,))
# Default to sharding the model across the whole world size when tp_size is not provided
tp_size = tp_size if tp_size is not None else torch.distributed.get_world_size()
if dp_size is not None and tp_size * dp_size != torch.distributed.get_world_size():
raise ValueError(
f"tp_size * dp_size ({tp_size} * {dp_size}) must be equal to the world size {torch.distributed.get_world_size()}."
)
device_mesh_shape = (dp_size, tp_size) if dp_size else (tp_size,)
mesh_dim_names = ("dp", "tp") if dp_size else ("tp",)
device_mesh = torch.distributed.init_device_mesh(
tp_device.type, device_mesh_shape, mesh_dim_names=mesh_dim_names
)
if use_auth_token is not None:
warnings.warn(
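For context, here is a minimal sketch (not part of this diff) of the 2-D mesh the added code builds when both degrees are passed: the outer `dp` dimension is what Accelerate later composes with FSDP/DDP, while the inner `tp` sub-mesh is the one sliced out via `device_mesh["tp"]` and handed to `shard_and_distribute_module`. The `nccl`/`cuda` choices and the sizes are assumptions for illustration, and the snippet expects a `torchrun` launch with 8 processes.

```python
# Illustrative sketch only: builds the same ("dp", "tp") mesh shape as the hunk above.
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh

dist.init_process_group("nccl")      # torchrun sets RANK/WORLD_SIZE/MASTER_* for us
dp_size, tp_size = 2, 4              # dp_size * tp_size must equal the world size (8 here)

mesh = init_device_mesh("cuda", (dp_size, tp_size), mesh_dim_names=("dp", "tp"))

tp_mesh = mesh["tp"]                 # 1-D sub-mesh used for parameter sharding
tp_rank = mesh.get_local_rank("tp")  # position within the TP group, 0..tp_size-1
dp_rank = mesh.get_local_rank("dp")  # position within the DP group, 0..dp_size-1
```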
@@ -4368,11 +4409,15 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
dtype=torch_dtype,
hf_quantizer=hf_quantizer,
keep_in_fp32_regex=keep_in_fp32_regex,
device_mesh=device_mesh,
device_mesh=device_mesh["tp"],
key_mapping=key_mapping,
weights_only=weights_only,
)
# record tp/dp info for accelerate
model._tp_size = tp_size
model._dp_size = dp_size
# make sure token embedding weights are still tied if needed
model.tie_weights()
@@ -4456,7 +4501,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
elif from_flax:
loading_info = None
return model, loading_info
return model
@staticmethod
@@ -4804,7 +4848,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
is_safetensors=is_offloaded_safetensors,
keep_in_fp32_regex=keep_in_fp32_regex,
unexpected_keys=unexpected_keys,
device_mesh=device_mesh,
device_mesh=device_mesh["tp"],
)
# force memory release if loading multiple shards, to avoid having 2 state dicts in memory in next loop
@@ -4854,6 +4898,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
for name, param in parameters_to_initialize.items():
# First move data to the correct device
to_contiguous, casting_dtype = _infer_parameter_dtype(model, name, param, keep_in_fp32_regex)
rank = device_mesh.get_local_rank("tp")
shard_and_distribute_module(
model,
param.to(tp_device),
@@ -4861,8 +4906,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
name,
casting_dtype,
to_contiguous,
os.environ["RANK"],
device_mesh,
rank,
device_mesh["tp"],
)
# All potential warnings/infos
@@ -5100,6 +5145,21 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
return True
return False
@property
def tp_size(self):
"""
Returns the model's tensor parallelism degree.
"""
# if None, the model didn't undergo tensor parallel sharding
return self._tp_size
@property
def dp_size(self):
"""
Returns the model's data parallelism degree.
"""
return self._dp_size
@property
def supports_pp_plan(self):
if self._pp_plan is not None:
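
Taken together, the modeling_utils.py changes let callers request the parallel degrees directly from `from_pretrained` and read them back through the new `tp_size`/`dp_size` properties. A hedged usage sketch under those assumptions (the checkpoint name and degrees are illustrative, and the script must be launched with `torchrun` using `tp_size * dp_size` processes):

```python
# torchrun --nproc-per-node 8 train.py   (8 = tp_size * dp_size)
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B",  # illustrative checkpoint; needs a predefined TP plan
    tp_plan="auto",             # must be set whenever tp_size is passed; only "auto" is accepted
    tp_size=4,                  # tensor parallel degree; defaults to the world size if omitted
    dp_size=2,                  # optional data parallel degree, used only to shape the device mesh
)

print(model.tp_size, model.dp_size)  # recorded for Accelerate: 4 2
```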

View File

@@ -459,7 +459,7 @@ class Trainer:
self.hp_name = None
self.deepspeed = None
self.is_in_train = False
self.model = model
self.create_accelerator_and_postprocess()
# memory metrics - must set up as early as possible
@@ -5137,12 +5137,16 @@ class Trainer:
args.update(accelerator_config)
# TP is initialized during the Accelerator init phase, so
# the args have to be prepared here
if self.args.tp_size > 1:
if hasattr(self.model, "tp_size") and self.model.tp_size is not None and self.model.tp_size > 1:
self.is_tp_enabled = True
if version.parse(accelerate_version) > version.parse("1.3.0"):
args["torch_tp_plugin"] = TorchTensorParallelPlugin(tp_size=self.args.tp_size)
args["torch_tp_plugin"] = TorchTensorParallelPlugin(tp_size=self.model.tp_size)
args["tp_size"] = self.model.tp_size
else:
raise ValueError("Requires accelerate>1.3.0 to use Tensor Parallelism.")
if hasattr(self.model, "dp_size") and self.model.dp_size is not None and self.model.dp_size > 1:
# TODO: version check
args["dp_size"] = self.model.dp_size
# create accelerator object
self.accelerator = Accelerator(**args)
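
With this change the tensor parallel degree no longer comes from `TrainingArguments` (the `tp_size` argument is removed in the TrainingArguments diff that follows); `Trainer` instead reads the `model.tp_size` and `model.dp_size` that `from_pretrained` recorded and builds the `TorchTensorParallelPlugin` itself. A rough migration sketch, with an illustrative checkpoint and degrees, assuming accelerate > 1.3.0:

```python
from transformers import AutoModelForCausalLM, Trainer, TrainingArguments

# Before this PR: TrainingArguments(..., tp_size=4). Now the degrees travel with the model.
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B",  # illustrative checkpoint
    tp_plan="auto",
    tp_size=4,
    dp_size=2,
)

args = TrainingArguments(output_dir="out")   # no tp_size argument anymore
trainer = Trainer(model=model, args=args)    # picks up model.tp_size / model.dp_size internally
```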

View File

@@ -554,10 +554,6 @@ class TrainingArguments:
Will use gradient checkpointing over each nested XLA FSDP wrapped layer. This setting can only be
used when the xla flag is set to true, and an auto wrapping policy is specified through
fsdp_min_num_params or fsdp_transformer_layer_cls_to_wrap.
tp_size (`int`, *optional*):
Use tp_size to enable PyTorch tensor parallelism. Tensor parallelism support is only available to models having `base_tp_plan`
in their respective config classes.
Set a value greater than 1 to activate TP. The same is used to prepare device mesh internally. Requires accelerate>1.3.0.
deepspeed (`str` or `dict`, *optional*):
Use [Deepspeed](https://github.com/deepspeedai/DeepSpeed). This is an experimental feature and its API may
evolve in the future. The value is either the location of DeepSpeed json config file (e.g.,
@@ -1244,18 +1240,6 @@ class TrainingArguments:
)
},
)
tp_size: Optional[int] = field(
default=0,
metadata={
"help": (
"Use tp_size to enable pytorch tensor parallelism."
"Tensor parallelism support is only available to models having `base_tp_plan` in their respective config classes."
"Set a value greater than 1 to activate TP."
"The same is used to prepare device mesh internally."
"Requires accelerate>1.3.0."
)
},
)
fsdp_transformer_layer_cls_to_wrap: Optional[str] = field(
default=None,
metadata={
@@ -1941,14 +1925,6 @@ class TrainingArguments:
if self.fsdp_config["xla_fsdp_grad_ckpt"]:
warnings.warn("`--xla_fsdp_grad_ckpt` is useful only when `--xla` is set to true.")
if self.tp_size > 1:
if not is_accelerate_available("1.3.1"):
raise NotImplementedError(
"TP using PyTorch requires Accelerate version `accelerate` >= 1.3.1. "
"This is not supported and we recommend you to update your version."
)
os.environ["ACCELERATE_USE_TP"] = "true"
os.environ["TP_SIZE"] = str(self.tp_size)
# accelerate integration for FSDP
if len(self.fsdp) > 0 and not self.fsdp_config["xla"]:
os.environ["ACCELERATE_USE_FSDP"] = "true"