mirror of https://github.com/huggingface/transformers.git
synced 2025-10-20 17:13:56 +08:00
Compare commits
7 Commits
740f952218...composable

Commits:
- d4b01f2fd3
- f13e867c7b
- c07373687d
- d77505c06e
- a65130c6fa
- bb2950d149
- 1059fffb45

@@ -674,29 +674,7 @@ use_cpu: false
```

</hfoption>
<hfoption id="Tensor Parallelism with PyTorch 2">

```yml
compute_environment: LOCAL_MACHINE
tp_config:
  tp_size: 4
distributed_type: TP
downcast_bf16: 'no'
machine_rank: 0
main_training_function: main
mixed_precision: 'no'
num_machines: 1
num_processes: 4
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
```

</hfoption>
</hfoptions>

The [`accelerate_launch`](https://huggingface.co/docs/accelerate/package_reference/cli#accelerate-launch) command is the recommended way to run your training script on a distributed system with Accelerate and [`Trainer`], using the parameters specified in `config_file.yaml`. This file is saved to the Accelerate cache folder and is loaded automatically when you run `accelerate_launch`.

@@ -341,29 +341,9 @@ use_cpu: false
```

</hfoption>
<hfoption id="Tensor parallelism with PyTorch 2">

```yaml
compute_environment: LOCAL_MACHINE
tp_config:
  tp_size: 4
distributed_type: TP
downcast_bf16: 'no'
machine_rank: 0
main_training_function: main
mixed_precision: 'no'
num_machines: 1
num_processes: 4
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
```

</hfoptions>

Run [accelerate_launch](https://hf.co/docs/accelerate/package_reference/cli#accelerate-launch) to start training with the configurations set in `config_file.yaml`. This file is saved to the Accelerate cache folder and automatically loaded when you run `accelerate_launch`.

The example below launches the [run_glue.py](../../../examples/pytorch/text-classification/run_glue) script with the FSDP configuration shown earlier. Parameters from the `config_file.yaml` file can also be directly set in the command line.

@@ -363,29 +363,6 @@ use_cpu: false

</hfoption>

<hfoption id="Tensor Parallelism with PyTorch 2">

```yml
compute_environment: LOCAL_MACHINE
tp_config:
  tp_size: 4
distributed_type: TP
downcast_bf16: 'no'
machine_rank: 0
main_training_function: main
mixed_precision: 'no'
num_machines: 1
num_processes: 4
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
```

</hfoption>
</hfoptions>

The [`accelerate_launch`](https://huggingface.co/docs/accelerate/package_reference/cli#accelerate-launch) command is the recommended way to launch your training script on a distributed system with Accelerate and [`Trainer`], using the parameters specified in `config_file.yaml`. This file is saved to the Accelerate cache folder and is loaded automatically when you run `accelerate_launch`.

@@ -549,29 +549,7 @@ use_cpu: false
```

</hfoption>
<hfoption id="Tensor Parallelism with PyTorch 2">

```yml
compute_environment: LOCAL_MACHINE
tp_config:
  tp_size: 4
distributed_type: TP
downcast_bf16: 'no'
machine_rank: 0
main_training_function: main
mixed_precision: 'no'
num_machines: 1
num_processes: 4
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
```

</hfoption>
</hfoptions>

The [`accelerate_launch`](https://huggingface.co/docs/accelerate/package_reference/cli#accelerate-launch) command is the recommended way to run your training script on a distributed system with Accelerate and [`Trainer`], using the parameters specified in `config_file.yaml`. This file is saved to the Accelerate cache folder and is loaded automatically when you run `accelerate_launch`.

@@ -734,6 +734,7 @@ def _load_state_dict_into_meta_model(
            )

        if device_mesh is not None:  # In this case, the param is already on the correct device!
            rank = device_mesh.get_local_rank("tp")
            shard_and_distribute_module(
                model,
                param,
@@ -741,8 +742,8 @@ def _load_state_dict_into_meta_model(
                param_name,
                casting_dtype,
                to_contiguous,
                int(os.environ["RANK"]),  # the rank
                device_mesh,
                rank,  # the rank
                device_mesh["tp"],
            )
        else:
            param = param[...]

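To build intuition for the change above, here is a hedged, standalone sketch of what sharding a parameter over the `"tp"` sub-mesh means, using plain DTensor primitives rather than the actual `shard_and_distribute_module` helper; the mesh shape, device type, and backend are assumptions for the example:

```python
# Illustrative only: run under `torchrun --nproc-per-node 4 sketch.py` (4 ranks assumed).
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Shard, distribute_tensor  # public in recent PyTorch releases

dist.init_process_group("gloo")  # gloo so the sketch also runs on CPU-only machines
mesh = init_device_mesh("cpu", (2, 2), mesh_dim_names=("dp", "tp"))

weight = torch.randn(8, 8)
tp_rank = mesh.get_local_rank("tp")  # analogous to `rank = device_mesh.get_local_rank("tp")` above
# Split rows of the weight across the 2 TP ranks; DP replicas see the same sharding.
sharded = distribute_tensor(weight, mesh["tp"], placements=[Shard(0)])

print(f"tp rank {tp_rank} holds a local shard of shape {tuple(sharded.to_local().shape)}")  # (4, 8)
dist.destroy_process_group()
```
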
@@ -1784,6 +1785,13 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
    # for example.
    _tp_plan = None

    # tensor parallel degree to which the model is sharded
    _tp_size = None

    # data parallel degree to be used, if any;
    # forwarded to Accelerate to build the correct device mesh
    _dp_size = None

    # A pipeline parallel plan specifying the layers which may not be present
    # on all ranks when PP is enabled. For top-level models, this attribute is
    # currently defined in respective model code. For base models, this

@@ -3845,6 +3853,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
            A torch tensor parallel plan, see [here](https://pytorch.org/tutorials/intermediate/TP_tutorial.html). Currently, it only accepts
            `tp_plan="auto"` to use a predefined plan based on the model. Note that if you use it, you should launch your script accordingly with
            `torchrun [args] script.py`. This will be much faster than using a `device_map`, but has limitations.
        tp_size (`int`, *optional*):
            A torch tensor parallel degree. If not provided, defaults to the world size.
        dp_size (`int`, *optional*):
            A torch data parallel degree. Only used to create the correct device mesh when `tp_size` is provided;
            Accelerate uses it to compose TP with FSDP/DDP.
        offload_folder (`str` or `os.PathLike`, *optional*):
            If the `device_map` contains any value `"disk"`, the folder where we will offload weights.
        offload_state_dict (`bool`, *optional*):

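A minimal usage sketch of these arguments as documented above; the model id, the chosen degrees, and the `torchrun` launch line are illustrative assumptions, not part of this diff:

```python
# Illustrative only: launch with e.g. `torchrun --nproc-per-node 4 load_tp.py`.
import torch.distributed as dist
from transformers import AutoModelForCausalLM

# tp_plan="auto" selects the model's predefined tensor-parallel plan.
# With 4 processes, tp_size=2 and dp_size=2 split them into 2 TP ranks x 2 DP replicas
# (tp_size * dp_size must equal the world size; tp_size alone defaults to the world size).
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-1B",  # assumed model id for the example
    tp_plan="auto",
    tp_size=2,
    dp_size=2,
)

print(model.tp_size, model.dp_size)  # 2 2, recorded on the model for Accelerate/Trainer

if dist.is_initialized():
    dist.destroy_process_group()
```
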
@@ -3941,6 +3954,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
        generation_config = kwargs.pop("generation_config", None)
        gguf_file = kwargs.pop("gguf_file", None)
        tp_plan = kwargs.pop("tp_plan", None)
        tp_size = kwargs.pop("tp_size", None)
        dp_size = kwargs.pop("dp_size", None)
        key_mapping = kwargs.pop("key_mapping", None)
        # Not used anymore -- remove them from the kwargs
        _ = kwargs.pop("resume_download", None)
@@ -3953,7 +3968,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
            raise ValueError(
                "`state_dict` cannot be passed together with a model name or a `gguf_file`. Use one of the two loading strategies."
            )

        if tp_size is not None and tp_plan is None:
            raise ValueError("tp_plan has to be set when tp_size is passed.")
        if tp_plan is not None and tp_plan != "auto":
            # TODO: we can relax this check when we support taking tp_plan from a json file, for example.
            raise ValueError(f"tp_plan supports 'auto' only for now but got {tp_plan}.")

@@ -3961,6 +3977,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
            raise ValueError(
                "`tp_plan` and `device_map` are mutually exclusive. Choose either one for parallelization."
            )
        if tp_size is None and dp_size is not None:
            # dp_size is only meaningful together with tp_size, so drop it otherwise
            dp_size = None

        # If torchrun was used, default to TP so people don't need to change tp or device_map
        if device_map == "auto" and tp_plan is None and int(os.environ.get("WORLD_SIZE", 0)):

@@ -4007,9 +4025,32 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
                sys.stderr = open(os.devnull, "w")
            # This is the easiest way to dispatch to the current process device
            device_map = tp_device
            # Assume the model is sharded across the whole world
            world_size = torch.distributed.get_world_size()
            device_mesh = torch.distributed.init_device_mesh(tp_device.type, (world_size,))

            # Assume the model is sharded across the whole world when tp_size is not provided
            tp_size = tp_size if tp_size is not None else torch.distributed.get_world_size()
            if dp_size is not None and tp_size * dp_size != torch.distributed.get_world_size():
                raise ValueError(
                    f"tp_size * dp_size ({tp_size} * {dp_size}) must be equal to the world size {torch.distributed.get_world_size()}."
                )
            device_mesh_shape = (dp_size, tp_size) if dp_size else (tp_size,)
            mesh_dim_names = ("dp", "tp") if dp_size else ("tp",)
            device_mesh = torch.distributed.init_device_mesh(
                tp_device.type, device_mesh_shape, mesh_dim_names=mesh_dim_names
            )

        if use_auth_token is not None:
            warnings.warn(

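As a standalone illustration of the mesh construction above (degrees, backend, and device type are assumptions for the example): with no `dp_size` a flat `("tp",)` mesh covers the whole world, while a provided `dp_size` reshapes the world into a 2-D `("dp", "tp")` mesh after the `tp_size * dp_size == world_size` check.

```python
# Illustrative only: run under `torchrun --nproc-per-node 8 mesh_shapes.py` (8 ranks assumed).
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh

dist.init_process_group("gloo")  # gloo so the sketch also runs without GPUs
world_size = dist.get_world_size()

tp_size, dp_size = 4, 2  # assumed degrees; set dp_size = None for pure TP over the whole world
if dp_size is not None and tp_size * dp_size != world_size:
    raise ValueError(f"tp_size * dp_size ({tp_size} * {dp_size}) must equal the world size {world_size}.")

mesh_shape = (dp_size, tp_size) if dp_size else (tp_size,)
mesh_dim_names = ("dp", "tp") if dp_size else ("tp",)
mesh = init_device_mesh("cpu", mesh_shape, mesh_dim_names=mesh_dim_names)

# With dp_size=2, tp_size=4 and 8 ranks: mesh shape (2, 4); each rank's TP group spans 4 ranks.
print(mesh.shape, mesh["tp"].size())
dist.destroy_process_group()
```
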
@@ -4368,11 +4409,15 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
                dtype=torch_dtype,
                hf_quantizer=hf_quantizer,
                keep_in_fp32_regex=keep_in_fp32_regex,
                device_mesh=device_mesh,
                device_mesh=device_mesh["tp"],
                key_mapping=key_mapping,
                weights_only=weights_only,
            )

            # record tp/dp info for accelerate
            model._tp_size = tp_size
            model._dp_size = dp_size

        # make sure token embedding weights are still tied if needed
        model.tie_weights()

@@ -4456,7 +4501,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
        elif from_flax:
            loading_info = None
            return model, loading_info

        return model

    @staticmethod

@@ -4804,7 +4848,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
                is_safetensors=is_offloaded_safetensors,
                keep_in_fp32_regex=keep_in_fp32_regex,
                unexpected_keys=unexpected_keys,
                device_mesh=device_mesh,
                device_mesh=device_mesh["tp"],
            )

            # force memory release if loading multiple shards, to avoid having 2 state dicts in memory in next loop

@@ -4854,6 +4898,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
            for name, param in parameters_to_initialize.items():
                # First move data to correct
                to_contiguous, casting_dtype = _infer_parameter_dtype(model, name, param, keep_in_fp32_regex)
                rank = device_mesh.get_local_rank("tp")
                shard_and_distribute_module(
                    model,
                    param.to(tp_device),
@@ -4861,8 +4906,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
                    name,
                    casting_dtype,
                    to_contiguous,
                    os.environ["RANK"],
                    device_mesh,
                    rank,
                    device_mesh["tp"],
                )

            # All potential warnings/infos

@@ -5100,6 +5145,21 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
                return True
        return False

    @property
    def tp_size(self):
        """
        Returns the model's tensor parallelism degree.
        """
        # if None, the model didn't undergo tensor parallel sharding
        return self._tp_size

    @property
    def dp_size(self):
        """
        Returns the model's data parallelism degree.
        """
        return self._dp_size

    @property
    def supports_pp_plan(self):
        if self._pp_plan is not None:

@@ -459,7 +459,7 @@ class Trainer:
        self.hp_name = None
        self.deepspeed = None
        self.is_in_train = False

        self.model = model
        self.create_accelerator_and_postprocess()

        # memory metrics - must set up as early as possible

@@ -5137,12 +5137,16 @@ class Trainer:
            args.update(accelerator_config)
        # TP is initialized at Accelerator init phase, so args should be prepared here
        if self.args.tp_size > 1:
        if hasattr(self.model, "tp_size") and self.model.tp_size is not None and self.model.tp_size > 1:
            self.is_tp_enabled = True
            if version.parse(accelerate_version) > version.parse("1.3.0"):
                args["torch_tp_plugin"] = TorchTensorParallelPlugin(tp_size=self.args.tp_size)
                args["torch_tp_plugin"] = TorchTensorParallelPlugin(tp_size=self.model.tp_size)
                args["tp_size"] = self.model.tp_size
            else:
                raise ValueError("Requires accelerate>1.3.0 to use Tensor Parallelism.")
        if hasattr(self.model, "dp_size") and self.model.dp_size is not None and self.model.dp_size > 1:
            # TODO: version check
            args["dp_size"] = self.model.dp_size

        # create accelerator object
        self.accelerator = Accelerator(**args)

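For context, here is a hedged sketch of the detection logic above pulled out of `Trainer`; the standalone helper is an assumption, while `TorchTensorParallelPlugin(tp_size=...)` and passing `torch_tp_plugin` to `Accelerator` come from the diff and require accelerate > 1.3.0. The extra `args["tp_size"]` / `args["dp_size"]` entries forwarded above appear to target a companion Accelerate change, so they are omitted here:

```python
# Hedged sketch, not the Trainer implementation.
from accelerate import Accelerator
from accelerate.utils import TorchTensorParallelPlugin


def build_accelerator(model) -> Accelerator:
    """Create an Accelerator, attaching a TP plugin only if the model was TP-sharded at load time."""
    kwargs = {}
    tp_size = getattr(model, "tp_size", None)  # set by from_pretrained(..., tp_plan="auto", tp_size=...)
    if tp_size is not None and tp_size > 1:
        kwargs["torch_tp_plugin"] = TorchTensorParallelPlugin(tp_size=tp_size)
    return Accelerator(**kwargs)


class _NoTPModel:  # stand-in for a model loaded without a tp_plan
    tp_size = None


accelerator = build_accelerator(_NoTPModel())  # falls back to a plain Accelerator
```
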
@@ -554,10 +554,6 @@ class TrainingArguments:
            Will use gradient checkpointing over each nested XLA FSDP wrapped layer. This setting can only be
            used when the xla flag is set to true, and an auto wrapping policy is specified through
            fsdp_min_num_params or fsdp_transformer_layer_cls_to_wrap.
        tp_size (`int`, *optional*):
            Use tp_size to enable PyTorch tensor parallelism. Tensor parallelism support is only available to models having `base_tp_plan`
            in their respective config classes.
            Set a value greater than 1 to activate TP. The same is used to prepare device mesh internally. Requires accelerate>1.3.0.
        deepspeed (`str` or `dict`, *optional*):
            Use [Deepspeed](https://github.com/deepspeedai/DeepSpeed). This is an experimental feature and its API may
            evolve in the future. The value is either the location of DeepSpeed json config file (e.g.,

@@ -1244,18 +1240,6 @@ class TrainingArguments:
            )
        },
    )
    tp_size: Optional[int] = field(
        default=0,
        metadata={
            "help": (
                "Use tp_size to enable pytorch tensor parallelism."
                "Tensor parallelism support is only available to models having `base_tp_plan` in their respective config classes."
                "Set a value greater than 1 to activate TP."
                "The same is used to prepare device mesh internally."
                "Requires accelerate>1.3.0."
            )
        },
    )
    fsdp_transformer_layer_cls_to_wrap: Optional[str] = field(
        default=None,
        metadata={

@@ -1941,14 +1925,6 @@ class TrainingArguments:
        if self.fsdp_config["xla_fsdp_grad_ckpt"]:
            warnings.warn("`--xla_fsdp_grad_ckpt` is useful only when `--xla` is set to true.")

        if self.tp_size > 1:
            if not is_accelerate_available("1.3.1"):
                raise NotImplementedError(
                    "TP using PyTorch requires Accelerate version `accelerate` >= 1.3.1. "
                    "This is not supported and we recommend you to update your version."
                )
            os.environ["ACCELERATE_USE_TP"] = "true"
            os.environ["TP_SIZE"] = str(self.tp_size)
        # accelerate integration for FSDP
        if len(self.fsdp) > 0 and not self.fsdp_config["xla"]:
            os.environ["ACCELERATE_USE_FSDP"] = "true"

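Taken together, these removals suggest the TP degree moves from `TrainingArguments` to model loading. A hedged migration sketch implied by this diff (not official upgrade guidance; the model id and degree are illustrative):

```python
# Illustrative only: launch with e.g. `torchrun --nproc-per-node 4 train_tp.py`.
from transformers import AutoModelForCausalLM, Trainer, TrainingArguments

model_id = "meta-llama/Llama-3.2-1B"  # assumed model id for the example

# Before this change, the degree lived on TrainingArguments (field removed above):
#     args = TrainingArguments(output_dir="out", tp_size=4)
# After this change, the degree is fixed at load time and Trainer reads model.tp_size:
model = AutoModelForCausalLM.from_pretrained(model_id, tp_plan="auto", tp_size=4)
args = TrainingArguments(output_dir="out")
trainer = Trainer(model=model, args=args)  # detects model.tp_size and configures Accelerate accordingly
```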