Improve config handling and add a zoo (#3029)

* Improve config handling and add a zoo

* Docs

* rm comment

* Tweak doc
Zach Mueller
2024-08-20 10:40:21 -04:00
committed by GitHub
parent 52fae0960c
commit 1a6af0bd6d
13 changed files with 157 additions and 39 deletions

View File

@@ -157,6 +157,8 @@ accelerate launch --multi_gpu --num_processes 2 examples/nlp_example.py
To learn more, check the CLI documentation available [here](https://huggingface.co/docs/accelerate/package_reference/cli).
Or view the configuration zoo [here](https://github.com/huggingface/accelerate/blob/main/examples/config_yaml_templates/)
## Launching multi-CPU run using MPI
🤗 Here is another way to launch a multi-CPU run using MPI. You can learn how to install Open MPI on [this page](https://www.open-mpi.org/faq/?category=building#easy-build). You can use Intel MPI or MVAPICH as well.
@@ -256,7 +258,7 @@ pip install accelerate
- multi-GPU on several nodes (machines)
- TPU
- FP16/BFloat16 mixed precision
- FP8 mixed precision with [Transformer Engine](https://github.com/NVIDIA/TransformerEngine) or [MS-AMP](https://github.com/Azure/MS-AMP/)
- DeepSpeed support (Experimental)
- PyTorch Fully Sharded Data Parallel (FSDP) support (Experimental)
- Megatron-LM support (Experimental)

View File

@@ -53,6 +53,8 @@ accelerate launch path_to_script.py --args_for_the_script
To learn more, check out the [Launch distributed code](basic_tutorials/launch) tutorial for more information about launching your scripts.
We also have a [configuration zoo](https://github.com/huggingface/accelerate/blob/main/examples/config_yaml_templates) which showcases a number of premade **minimal** example configurations for a variety of setups you can run.
## Adapt training code
The next main feature of Accelerate is the [`Accelerator`] class which adapts your PyTorch code to run on different distributed setups.

View File

@@ -208,23 +208,13 @@ To run it in each of these various modes, use the following commands:
- [huggan project](https://github.com/huggingface/community-events/tree/main/huggan)
### Using AWS SageMaker integration
- [Examples showcasing AWS SageMaker integration of 🤗 Accelerate.](https://github.com/pacman100/accelerate-aws-sagemaker)
## Configuration zoo
## Configuration zoo
In [/config_yaml_templates](./config_yaml_templates/) we have a variety of *minimal* `config.yaml` templates and examples to help you learn
how to create your own configuration files depending on the scenario.
## SLURM Scripts
In [/slurm/submit_multigpu.sh](./slurm/submit_multigpu.sh) and [/slurm/submit_multinode.sh](./slurm/submit_multinode.sh) we present two scripts for running the examples on a machine with the [SLURM](https://slurm.schedmd.com/documentation.html) workload manager.
@@ -251,6 +241,20 @@ export PYTHONPATH=/home/nct01/nct01328/transformers-in-supercomputers:$PYTHONPATH
export GPUS_PER_NODE=4
```
## Simple Multi-GPU Hardware Launcher (using an external platform)
[multigpu_remote_launcher.py](./multigpu_remote_launcher.py) is a minimal script that demonstrates launching Accelerate
on multiple remote GPUs, with automatic hardware environment and dependency setup for reproducibility. You can
easily customize the training function, training arguments, hyperparameters, and type of compute hardware, and then
run the script to automatically launch multi-GPU training on remote hardware.
This script uses [Runhouse](https://github.com/run-house/runhouse) to launch on self-hosted hardware (e.g. in your own
cloud account or on-premise cluster) but there are other options for running remotely as well. Runhouse can be installed
with `pip install runhouse`, and you can refer to
[hardware setup](https://runhouse-docs.readthedocs-hosted.com/en/latest/api/python/cluster.html#hardware-setup)
for hardware setup instructions, or this
[Colab tutorial](https://colab.research.google.com/drive/1qVwYyLTCPYPSdz9ZX7BZl9Qm0A3j7RJe) for a more in-depth walkthrough.
## Finer Examples
While the first two scripts are extremely barebones when it comes to what you can do with accelerate, more advanced features are documented in two other locations.

View File

@@ -0,0 +1,10 @@
# Config Zoo
This folder contains a variety of minimal configurations for `Accelerate`, each achieving a particular goal. You can use these
config YAMLs directly, or build off of them for your own YAMLs.
These are highly annotated versions, aiming to teach you what each section does.
Each config can be run via `accelerate launch --config_file {file} run_me.py`
`run_me.py` will then print out how the current environment is set up (the contents of the `AcceleratorState`)
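
For a quick sanity check before launching, a zoo template can also be parsed locally. The sketch below assumes PyYAML is installed, uses a placeholder file name, and relies on `load_config_from_file` from Accelerate's internal `accelerate.commands.config.config_args` module, whose location may change between versions:

```python
# Minimal sketch: inspect a zoo template without launching anything.
# `multi_gpu.yaml` is a placeholder name for one of the templates in this folder.
import yaml

with open("multi_gpu.yaml", encoding="utf-8") as f:
    raw = yaml.safe_load(f)
print(raw["distributed_type"], raw.get("mixed_precision", "no"))

try:
    # Internal helper; this import path is an assumption and may differ across versions.
    from accelerate.commands.config.config_args import load_config_from_file

    config = load_config_from_file("multi_gpu.yaml")  # returns a config object with defaults filled in
    print(config)
except ImportError:
    pass
```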

View File

@@ -0,0 +1,15 @@
# Similar to FSDP, we set the distributed type as DEEPSPEED
distributed_type: DEEPSPEED
# With DeepSpeed, we utilize a deepspeed config file for the entire configuration
deepspeed_config:
# Can also be any of the config JSONs in accelerate/examples/deepspeed_config_templates
deepspeed_config_file: ../deepspeed_config_templates/zero_stage1_config.json
# If using ZeRO-3 and you want to load in big models, this should be set to `true` so
# `transformers` uses the right `init` function
zero3_init_flag: false # set to `true` when loading big models with ZeRO-3
# Finally we need to specify the number of GPUs to use
num_processes: 2
# Optionally, the mixed precision can be set here instead of in the deepspeed config file;
# however, this requires the `fp16` and `bf16` options to be set to `auto` in the deepspeed config file
# mixed_precision: "bf16"
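
For context, a script launched with this config needs no DeepSpeed-specific code beyond the usual Accelerate pattern. A minimal sketch (the model, data, and file names are placeholders, e.g. `accelerate launch --config_file deepspeed.yaml train.py`):

```python
# train.py -- placeholder name; the same loop works under DeepSpeed, FSDP, or plain DDP.
import torch
from accelerate import Accelerator

accelerator = Accelerator()  # picks up the DeepSpeed settings from the launch config
model = torch.nn.Linear(128, 2)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
dataloader = torch.utils.data.DataLoader(torch.randn(64, 128), batch_size=8)

# `prepare` hands the model and optimizer over to the DeepSpeed engine
model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)

for batch in dataloader:
    loss = model(batch).mean()
    accelerator.backward(loss)  # use this instead of loss.backward() so the engine can handle scaling/partitioning
    optimizer.step()
    optimizer.zero_grad()
```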

View File

@@ -0,0 +1,18 @@
# This config template simply sets up the TransformerEngine config (and a config for a single GPU);
# it can interoperate with the other configs in this folder
distributed_type: "NO"
mixed_precision: "fp8"
# Then we specify the fp8 configuration:
fp8_config:
backend: TE # Can be TE | MS-AMP
# The following are TE specific arguments.
# See https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/api/common.html#common-api for more details
amax_history_length: 1024
fp8_format: E4M3
interval: 1
margin: 0
override_linear_precision: false
# Generally this should be set to `false` to get the most realistic fp8 evaluation performance
use_autocast_during_eval: false
# If using MS-AMP, all of the above is ignored and only an `opt_level` is set
#opt_level: O1
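
The same choice can also be made in code instead of YAML; a rough sketch, assuming the installed Accelerate version exposes `FP8RecipeKwargs` in `accelerate.utils` and that its keyword names mirror the fields above:

```python
from accelerate import Accelerator
from accelerate.utils import FP8RecipeKwargs

# Select the TransformerEngine backend for fp8; the other recipe fields (margin, interval,
# fp8_format, ...) are assumed to map onto keyword arguments of FP8RecipeKwargs.
fp8_handler = FP8RecipeKwargs(backend="TE")
accelerator = Accelerator(mixed_precision="fp8", kwargs_handlers=[fp8_handler])
accelerator.print(accelerator.fp8_recipe_handler)
```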

View File

@@ -0,0 +1,18 @@
# Since we are doing FSDP (even though it's multi-GPU), we need to specify the distributed type as FSDP
distributed_type: FSDP
# Can be one of "no", "fp16", or "bf16" (see `transformer_engine.yaml` for `fp8`, but it works for FSDP as well)
mixed_precision: 'bf16'
# Specify the number of GPUs to use
num_processes: 2
# Then we can specify the FSDP config
fsdp_config:
fsdp_activation_checkpointing: false
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
fsdp_backward_prefetch: BACKWARD_PRE
fsdp_cpu_ram_efficient_loading: true
fsdp_forward_prefetch: false
fsdp_offload_params: false
fsdp_sharding_strategy: FULL_SHARD
fsdp_state_dict_type: SHARDED_STATE_DICT
fsdp_sync_module_states: true
fsdp_use_orig_params: true
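
One practical note for scripts launched with this config: Accelerate's FSDP documentation recommends preparing the model before creating the optimizer, so the optimizer is built on the flattened, sharded parameters. A minimal sketch:

```python
import torch
from accelerate import Accelerator

accelerator = Accelerator()  # FSDP settings come from the launch config above
model = torch.nn.Sequential(torch.nn.Linear(128, 128), torch.nn.ReLU(), torch.nn.Linear(128, 2))

# Prepare (wrap) the model first, then build and prepare the optimizer on its parameters.
model = accelerator.prepare(model)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
optimizer = accelerator.prepare(optimizer)
```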

View File

@@ -0,0 +1,6 @@
# Specify distributed_type as `MULTI_GPU` for DDP
distributed_type: "MULTI_GPU"
# Can be one of "no", "fp16", or "bf16" (see `transformer_engine.yaml` for `fp8`)
mixed_precision: "bf16"
# Specify the number of GPUs to use
num_processes: 2
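
Inside the launched script, the two processes can be observed directly; a small sketch of the values this config produces (the gathered metric is purely illustrative):

```python
import torch
from accelerate import Accelerator

accelerator = Accelerator()  # with the config above: DDP, 2 processes, bf16
accelerator.print(f"{accelerator.num_processes} processes, this one on {accelerator.device}")

# Each process sees its own shard of the data, so the effective batch size is
# per_device_batch_size * num_processes.
per_device_batch_size = 8
effective_batch_size = per_device_batch_size * accelerator.num_processes

# Per-process values can be collected on every rank with `gather`.
local_value = torch.tensor([float(accelerator.process_index)], device=accelerator.device)
accelerator.print(effective_batch_size, accelerator.gather(local_value))
```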

View File

@@ -0,0 +1,16 @@
# This config template is for a multi-node setup; it assumes DDP, but it can interoperate with the other configs in this folder
# Generally it's recommended to look at the SLURM config template for a more robust multi-node setup
distributed_type: MULTI_GPU
# We need to specify the current machine's rank
machine_rank: 0
# We then need to specify the IP address and port of the main process
main_process_ip: '1234'
main_process_port: 9999
# We need to specify the number of machines
num_machines: 2
# We need to specify the *total* number of processes
num_processes: 8
# And then we need to specify how the rendezvous (rdzv) communication will be handled
rdzv_backend: static # or c10d
# Whether the compute nodes are on the same network (for cloud setups this will more than likely be false)
same_network: false
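
With two machines it helps to distinguish the single global main process from the per-node main process; a sketch of how a script launched with this config would typically use that distinction:

```python
from accelerate import Accelerator

accelerator = Accelerator()  # with the config above: 2 machines, 8 processes total

# True on exactly one process in the whole job (on the machine with machine_rank 0).
if accelerator.is_main_process:
    print("global main process: log metrics / save checkpoints here")

# True once per machine -- useful for node-local work such as priming a dataset cache.
if accelerator.is_local_main_process:
    print("local main process: download data to node-local storage here")

accelerator.wait_for_everyone()
```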

View File

@@ -0,0 +1,26 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
A base script which outputs the accelerate config for the given environment
"""
from accelerate import Accelerator
accelerator = Accelerator()
accelerator.print(f"Accelerator state from the current environment:\n{accelerator.state}")
if accelerator.fp8_recipe_handler is not None:
accelerator.print(f"FP8 config:\n{accelerator.fp8_recipe_handler}")
accelerator.end_training()

View File

@@ -0,0 +1,4 @@
# Since this is single GPU, we don't need distributed training
distributed_type: "NO"
# Can be one of "no", "fp16", or "bf16" (see `transformer_engine.yaml` for `fp8`)
mixed_precision: "bf16"

View File

@@ -99,13 +99,17 @@ class BaseConfig:
        result = {k: v for k, v in result.items() if v is not None}
        return result

    @staticmethod
    def process_config(config_dict):
        """
        Processes `config_dict` and sets default values for any missing keys
        """
        if "compute_environment" not in config_dict:
            config_dict["compute_environment"] = ComputeEnvironment.LOCAL_MACHINE
        if "distributed_type" not in config_dict:
            raise ValueError("A `distributed_type` must be specified in the config file.")
        if "num_processes" not in config_dict and config_dict["distributed_type"] == DistributedType.NO:
            config_dict["num_processes"] = 1
        if "mixed_precision" not in config_dict:
            config_dict["mixed_precision"] = "fp16" if ("fp16" in config_dict and config_dict["fp16"]) else None
        if "fp16" in config_dict:  # Convert the config to the new format.
@@ -119,6 +123,14 @@ class BaseConfig:
            config_dict["debug"] = False
        if "enable_cpu_affinity" not in config_dict:
            config_dict["enable_cpu_affinity"] = False
        return config_dict

    @classmethod
    def from_json_file(cls, json_file=None):
        json_file = default_json_config_file if json_file is None else json_file
        with open(json_file, encoding="utf-8") as f:
            config_dict = json.load(f)
        config_dict = cls.process_config(config_dict)
        extra_keys = sorted(set(config_dict.keys()) - set(cls.__dataclass_fields__.keys()))
        if len(extra_keys) > 0:
            raise ValueError(
@@ -138,23 +150,7 @@ class BaseConfig:
        yaml_file = default_yaml_config_file if yaml_file is None else yaml_file
        with open(yaml_file, encoding="utf-8") as f:
            config_dict = yaml.safe_load(f)
        config_dict = cls.process_config(config_dict)
        extra_keys = sorted(set(config_dict.keys()) - set(cls.__dataclass_fields__.keys()))
        if len(extra_keys) > 0:
            raise ValueError(
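
As a rough illustration of the consolidated behaviour (assuming `BaseConfig` remains importable from the internal `accelerate.commands.config.config_args` module), `process_config` now rejects configs with no `distributed_type` and fills `num_processes` for single-process runs:

```python
# Illustrative only; BaseConfig lives in an internal module whose path may change.
from accelerate.commands.config.config_args import BaseConfig

cfg = BaseConfig.process_config({"distributed_type": "NO"})
print(cfg["num_processes"])        # 1, filled in because distributed_type is NO
print(cfg["compute_environment"])  # defaults to LOCAL_MACHINE

try:
    BaseConfig.process_config({})
except ValueError as err:
    print(err)  # a `distributed_type` must be specified in the config file
```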

View File

@@ -1338,10 +1338,11 @@ class FullyShardedDataParallelPlugin:
        },
    )
    sync_module_states: bool = field(
        default=None,
        metadata={
            "help": "Whether each individually wrapped FSDP unit should broadcast module parameters from rank 0 "
            "to ensure they are the same across all ranks after initialization. Defaults to `False` unless "
            "`cpu_ram_efficient_loading` is `True`, then will be forcibly enabled."
        },
    )
    forward_prefetch: bool = field(
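
A sketch of the new default in action, assuming the installed version exposes the `cpu_ram_efficient_loading` field on the plugin:

```python
from accelerate import FullyShardedDataParallelPlugin

# With `cpu_ram_efficient_loading=True`, `sync_module_states` is forced on so that
# ranks initialized with empty weights receive the real parameters broadcast from rank 0.
fsdp_plugin = FullyShardedDataParallelPlugin(
    cpu_ram_efficient_loading=True,  # assumption: this field exists in the installed version
)
print(fsdp_plugin.sync_module_states)  # expected: True, per the help text above

# Inside a distributed launch the plugin is passed straight to the Accelerator:
# accelerator = Accelerator(fsdp_plugin=fsdp_plugin)
```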