DeepSpeed Revamp (#405)
* deepspeed revamp
* Update dataclasses.py
* Update deepspeed.py
* quality
* fixing code
* quality
* Fix imports
* saving 16bit model in zero stage 3
  1. Saving 16bit model in zero stage 3
  2. zero init in stage 3 support using HFDeepSpeedConfig
* quality
* adding test and fixing bugs
* update makefile for deepspeed tests
* Update test.yml
* adding `deepspeed` as requirement for tests
* Apply suggestions from code review
  Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
* quality
* addressing comments
* add example and minor updates
  1. Add example to show the usage of config file with revamped deepspeed support.
  2. update required deepspeed version to 0.6.5
  3. reverting `reinit` change as it is not required
  4. raising Exception when using `clip_grad_value` with DeepSpeed/FSDP.
* Documentation and ZeRO-3 Inference Support
  1. Changes to support ZeRO Stage-3 Inference.
  2. minor bug fixes.
  3. Documentation.
* doc fix
* Apply suggestions from code review
  Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
* addressing comments
* update doc to address comments and bug fixes
  1. update tests and add new one testing autofill functionality of `prepare` method.
  2. fix bug related to zero-3 init related to HFDeepSpeedConfig
  3. Update documentation addressing comments.
* removing image and hosting it on `documentation-images` dataset
* check for hidden_size for zero_opt heuristics

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Committed via GitHub
Parent: 05c641bc0c
Commit: 1703b79a79
.github/workflows/test.yml (vendored, 1 change)
@@ -13,6 +13,7 @@ jobs:
    matrix:
      test-kind: [
        test,
        test_deepspeed,
        test_example_differences,
        test_checkpoint_step,
        test_checkpoint_epoch,
Makefile (3 changes)
@@ -27,6 +27,9 @@ style:
test:
	python -m pytest -s -v ./tests/ --ignore=./tests/test_examples.py

test_deepspeed:
	python -m pytest -s -v ./tests/deepspeed

test_examples:
	python -m pytest -s -v ./tests/test_examples.py
@@ -29,4 +29,6 @@
    title: Fully Sharded Data Parallel
  - local: memory
    title: Memory Utilities
  - local: deepspeed
    title: DeepSpeed
  title: API Reference
docs/source/deepspeed.mdx (new file, 508 lines)
@@ -0,0 +1,508 @@
<!--Copyright 2022 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# DeepSpeed

[DeepSpeed](https://github.com/microsoft/DeepSpeed) implements everything described in the [ZeRO paper](https://arxiv.org/abs/1910.02054). Currently, it provides full support for:

1. Optimizer state partitioning (ZeRO stage 1)
2. Gradient partitioning (ZeRO stage 2)
3. Parameter partitioning (ZeRO stage 3)
4. Custom mixed precision training handling
5. A range of fast CUDA-extension-based optimizers
6. ZeRO-Offload to CPU and Disk/NVMe

ZeRO-Offload has its own dedicated paper: [ZeRO-Offload: Democratizing Billion-Scale Model Training](https://arxiv.org/abs/2101.06840). And NVMe support is described in the paper [ZeRO-Infinity: Breaking the GPU
Memory Wall for Extreme Scale Deep Learning](https://arxiv.org/abs/2104.07857).

DeepSpeed ZeRO-2 is primarily used only for training, as its features are of no use to inference.

DeepSpeed ZeRO-3 can be used for inference as well, since it allows huge models to be loaded on multiple GPUs, which
won't be possible on a single GPU.

🤗 Accelerate integrates [DeepSpeed](https://github.com/microsoft/DeepSpeed) via 2 options:

1. Integration of the DeepSpeed features via the `deepspeed config file` specification in `accelerate config`. You just supply your custom config file or use our template. Most of
this document is focused on this feature. It supports all the core features of DeepSpeed and gives the user a lot of flexibility.
The user may have to change a few lines of code depending on the config.
2. Integration via `deepspeed_plugin`. This supports a subset of the DeepSpeed features and uses default options for the rest of the configuration.
The user need not change any code, which is good for those who are fine with most of the default settings of DeepSpeed.

## What is integrated?

Training:

1. DeepSpeed ZeRO training supports the full ZeRO stages 1, 2 and 3 as well as CPU/Disk offload of optimizer states, gradients and parameters.
Below is a short description of Data Parallelism using ZeRO - Zero Redundancy Optimizer, along with a diagram from this [blog post](https://www.microsoft.com/en-us/research/blog/zero-deepspeed-new-system-optimizations-enable-training-models-with-over-100-billion-parameters/)


(Source: [link](https://www.microsoft.com/en-us/research/blog/zero-deepspeed-new-system-optimizations-enable-training-models-with-over-100-billion-parameters/))

a. **Stage 1**: Shards optimizer states across data parallel workers/GPUs

b. **Stage 2**: Shards optimizer states + gradients across data parallel workers/GPUs

c. **Stage 3**: Shards optimizer states + gradients + model parameters across data parallel workers/GPUs

d. **Optimizer Offload**: Offloads the gradients + optimizer states to CPU/Disk, building on top of ZeRO Stage 2

e. **Param Offload**: Offloads the model parameters to CPU/Disk, building on top of ZeRO Stage 3

<u>Note</u>: With respect to Disk Offload, the disk should be an NVMe for decent speed, but it technically works on any disk

Inference:

1. DeepSpeed ZeRO Inference supports ZeRO stage 3 with ZeRO-Infinity. It uses the same ZeRO protocol as training, but
it doesn't use an optimizer or an LR scheduler, and only stage 3 is relevant. For more details see:
[deepspeed-zero-inference](#deepspeed-zero-inference).


## How it works?

**Pre-Requisites**: Install DeepSpeed version >=0.6.5. Please refer to the [DeepSpeed installation details](https://github.com/microsoft/DeepSpeed#installation)
for more information.

We will first look at the easy-to-use integration via `accelerate config`,
followed by the more flexible and feature-rich `deepspeed config file` integration.

### Accelerate DeepSpeed Plugin
On your machine(s) just run:

```bash
accelerate config
```

and answer the questions asked. It will ask whether you want to use a config file for DeepSpeed, to which you should answer no. Then answer the following questions to generate a basic DeepSpeed config.
This will generate a config file that will be used automatically to properly set the
default options when doing

```bash
accelerate launch my_script.py --args_to_my_script
```

For instance, here is how you would run the NLP example `examples/nlp_example.py` (from the root of the repo) with DeepSpeed Plugin:

**ZeRO Stage-2 DeepSpeed Plugin Example**
```bash
compute_environment: LOCAL_MACHINE
deepspeed_config:
  gradient_accumulation_steps: 1
  gradient_clipping: 1.0
  offload_optimizer_device: none
  offload_param_device: none
  zero3_init_flag: true
  zero_stage: 2
distributed_type: DEEPSPEED
fsdp_config: {}
machine_rank: 0
main_process_ip: null
main_process_port: null
main_training_function: main
mixed_precision: fp16
num_machines: 1
num_processes: 2
use_cpu: false
```

```bash
accelerate launch examples/nlp_example.py --mixed_precision fp16
```

**ZeRO Stage-3 with CPU Offload DeepSpeed Plugin Example**
```bash
compute_environment: LOCAL_MACHINE
deepspeed_config:
  gradient_accumulation_steps: 1
  gradient_clipping: 1.0
  offload_optimizer_device: cpu
  offload_param_device: cpu
  zero3_init_flag: true
  zero3_save_16bit_model: true
  zero_stage: 3
distributed_type: DEEPSPEED
fsdp_config: {}
machine_rank: 0
main_process_ip: null
main_process_port: null
main_training_function: main
mixed_precision: fp16
num_machines: 1
num_processes: 2
use_cpu: false
```

```bash
accelerate launch examples/nlp_example.py --mixed_precision fp16
```

Currently, `Accelerate` supports the following config options through the CLI:

```bash
`zero_stage`: [0] Disabled, [1] optimizer state partitioning, [2] optimizer+gradient state partitioning and [3] optimizer+gradient+parameter partitioning
`gradient_accumulation_steps`: Number of training steps to accumulate gradients before averaging and applying them.
`gradient_clipping`: Enable gradient clipping with value.
`offload_optimizer_device`: [none] Disable optimizer offloading, [cpu] offload optimizer to CPU, [nvme] offload optimizer to NVMe SSD. Only applicable with ZeRO >= Stage-2.
`offload_param_device`: [none] Disable parameter offloading, [cpu] offload parameters to CPU, [nvme] offload parameters to NVMe SSD. Only applicable with ZeRO Stage-3.
`zero3_init_flag`: Decides whether to enable `deepspeed.zero.Init` for constructing massive models. Only applicable with ZeRO Stage-3.
`zero3_save_16bit_model`: Decides whether to save 16-bit model weights when using ZeRO Stage-3.
`mixed_precision`: `no` for FP32 training, `fp16` for FP16 mixed-precision training and `bf16` for BF16 mixed-precision training.
```
To be able to tweak more options, you will need to use a DeepSpeed config file.
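
The same options can also be set programmatically instead of answering the `accelerate config` questionnaire. Below is a minimal sketch, assuming `DeepSpeedPlugin` accepts the fields listed above as keyword arguments; the script still has to be started with `accelerate launch`:

```python
# A sketch (not the canonical setup): build the plugin in code instead of via `accelerate config`.
# The keyword arguments mirror the CLI options listed above and are assumed to be accepted as-is.
from accelerate import Accelerator
from accelerate.utils import DeepSpeedPlugin

deepspeed_plugin = DeepSpeedPlugin(zero_stage=2, gradient_accumulation_steps=1)
accelerator = Accelerator(mixed_precision="fp16", deepspeed_plugin=deepspeed_plugin)
```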

### DeepSpeed Config File
On your machine(s) just run:

```bash
accelerate config
```

and answer the questions asked. It will ask whether you want to use a config file for DeepSpeed, to which you should answer yes
and provide the path to the DeepSpeed config file.
This will generate a config file that will be used automatically to properly set the
default options when doing

```bash
accelerate launch my_script.py --args_to_my_script
```

For instance, here is how you would run the NLP example `examples/by_feature/deepspeed_with_config_support.py` (from the root of the repo) with DeepSpeed Config File:

**ZeRO Stage-2 DeepSpeed Config File Example**
```bash
compute_environment: LOCAL_MACHINE
deepspeed_config:
  deepspeed_config_file: /home/ubuntu/accelerate/examples/configs/deepspeed_config_templates/zero_stage2_config.json
  zero3_init_flag: true
distributed_type: DEEPSPEED
fsdp_config: {}
machine_rank: 0
main_process_ip: null
main_process_port: null
main_training_function: main
mixed_precision: fp16
num_machines: 1
num_processes: 2
use_cpu: false
```

with the contents of `zero_stage2_config.json` being:
```json
{
    "fp16": {
        "enabled": true,
        "loss_scale": 0,
        "loss_scale_window": 1000,
        "initial_scale_power": 16,
        "hysteresis": 2,
        "min_loss_scale": 1
    },
    "optimizer": {
        "type": "AdamW",
        "params": {
            "lr": "auto",
            "weight_decay": "auto",
            "torch_adam": true,
            "adam_w_mode": true
        }
    },
    "scheduler": {
        "type": "WarmupDecayLR",
        "params": {
            "warmup_min_lr": "auto",
            "warmup_max_lr": "auto",
            "warmup_num_steps": "auto",
            "total_num_steps": "auto"
        }
    },
    "zero_optimization": {
        "stage": 2,
        "allgather_partitions": true,
        "allgather_bucket_size": 2e8,
        "overlap_comm": true,
        "reduce_scatter": true,
        "reduce_bucket_size": "auto",
        "contiguous_gradients": true
    },
    "gradient_accumulation_steps": 1,
    "gradient_clipping": "auto",
    "steps_per_print": 2000,
    "train_batch_size": "auto",
    "train_micro_batch_size_per_gpu": "auto",
    "wall_clock_breakdown": false
}
```

```bash
accelerate launch examples/by_feature/deepspeed_with_config_support.py \
--config_name "gpt2-large" \
--tokenizer_name "gpt2-large" \
--dataset_name "wikitext" \
--dataset_config_name "wikitext-2-raw-v1" \
--block_size 128 \
--output_dir "./clm/clm_deepspeed_stage2_accelerate" \
--learning_rate 5e-4 \
--per_device_train_batch_size 24 \
--per_device_eval_batch_size 24 \
--num_train_epochs 3 \
--with_tracking \
--report_to "wandb"
```

**ZeRO Stage-3 with CPU offload DeepSpeed Config File Example**
```bash
compute_environment: LOCAL_MACHINE
deepspeed_config:
  deepspeed_config_file: /home/ubuntu/accelerate/examples/configs/deepspeed_config_templates/zero_stage3_offload_config.json
  zero3_init_flag: true
distributed_type: DEEPSPEED
fsdp_config: {}
machine_rank: 0
main_process_ip: null
main_process_port: null
main_training_function: main
mixed_precision: fp16
num_machines: 1
num_processes: 2
use_cpu: false
```
with the contents of `zero_stage3_offload_config.json` being:
```json
{
    "fp16": {
        "enabled": true,
        "loss_scale": 0,
        "loss_scale_window": 1000,
        "initial_scale_power": 16,
        "hysteresis": 2,
        "min_loss_scale": 1
    },
    "optimizer": {
        "type": "AdamW",
        "params": {
            "lr": "auto",
            "weight_decay": "auto"
        }
    },
    "scheduler": {
        "type": "WarmupDecayLR",
        "params": {
            "warmup_min_lr": "auto",
            "warmup_max_lr": "auto",
            "warmup_num_steps": "auto",
            "total_num_steps": "auto"
        }
    },
    "zero_optimization": {
        "stage": 3,
        "offload_optimizer": {
            "device": "cpu",
            "pin_memory": true
        },
        "offload_param": {
            "device": "cpu",
            "pin_memory": true
        },
        "overlap_comm": true,
        "contiguous_gradients": true,
        "reduce_bucket_size": "auto",
        "stage3_prefetch_bucket_size": "auto",
        "stage3_param_persistence_threshold": "auto",
        "sub_group_size": 1e9,
        "stage3_max_live_parameters": 1e9,
        "stage3_max_reuse_distance": 1e9,
        "stage3_gather_16bit_weights_on_model_save": "auto"
    },
    "gradient_accumulation_steps": 1,
    "gradient_clipping": "auto",
    "steps_per_print": 2000,
    "train_batch_size": "auto",
    "train_micro_batch_size_per_gpu": "auto",
    "wall_clock_breakdown": false
}
```

```bash
accelerate launch examples/by_feature/deepspeed_with_config_support.py \
--config_name "gpt2-large" \
--tokenizer_name "gpt2-large" \
--dataset_name "wikitext" \
--dataset_config_name "wikitext-2-raw-v1" \
--block_size 128 \
--output_dir "./clm/clm_deepspeed_stage3_offload_accelerate" \
--learning_rate 5e-4 \
--per_device_train_batch_size 32 \
--per_device_eval_batch_size 32 \
--num_train_epochs 3 \
--with_tracking \
--report_to "wandb"
```

**Important code changes when using DeepSpeed Config File**

1. DeepSpeed Optimizers and Schedulers. For more information on these,
see the [DeepSpeed Optimizers](https://deepspeed.readthedocs.io/en/latest/optimizers.html) and [DeepSpeed Schedulers](https://deepspeed.readthedocs.io/en/latest/schedulers.html) documentation.
We will look at the changes needed in the code when using these.

a. DS Optim + DS Scheduler: The case when both `optimizer` and `scheduler` keys are present in the DeepSpeed config file.
In this situation, those will be used, and the user has to use `accelerate.utils.DummyOptim` and `accelerate.utils.DummyScheduler` to replace the PyTorch/Custom optimizers and schedulers in their code.
Below is the snippet from `examples/by_feature/deepspeed_with_config_support.py` showing this:
```python
# Creates a Dummy Optimizer if `optimizer` was specified in the config file, else creates the AdamW optimizer
optimizer_cls = (
    torch.optim.AdamW
    if accelerator.state.deepspeed_plugin is None
    or "optimizer" not in accelerator.state.deepspeed_plugin.deepspeed_config
    else DummyOptim
)
optimizer = optimizer_cls(optimizer_grouped_parameters, lr=args.learning_rate)

# Creates a Dummy Scheduler if `scheduler` was specified in the config file, else creates the `args.lr_scheduler_type` scheduler
if (
    accelerator.state.deepspeed_plugin is None
    or "scheduler" not in accelerator.state.deepspeed_plugin.deepspeed_config
):
    lr_scheduler = get_scheduler(
        name=args.lr_scheduler_type,
        optimizer=optimizer,
        num_warmup_steps=args.num_warmup_steps,
        num_training_steps=args.max_train_steps,
    )
else:
    lr_scheduler = DummyScheduler(
        optimizer, total_num_steps=args.max_train_steps, warmup_num_steps=args.num_warmup_steps
    )
```
b. Custom Optim + Custom Scheduler: The case when both `optimizer` and `scheduler` keys are absent in the DeepSpeed config file.
In this situation, no code changes are needed from the user, and this is the case when using integration via the DeepSpeed Plugin.
In the above example, we can see that the code remains unchanged if the `optimizer` and `scheduler` keys are absent in the DeepSpeed config file.

c. Custom Optim + DS Scheduler: The case when only the `scheduler` key is present in the DeepSpeed config file.
In this situation, the user has to use `accelerate.utils.DummyScheduler` to replace the PyTorch/Custom scheduler in their code (see the sketch after this list).

d. DS Optim + Custom Scheduler: The case when only the `optimizer` key is present in the DeepSpeed config file.
This will result in an error because one can only use a DS Scheduler when using a DS Optim.

2. Notice the `auto` values in the above example DeepSpeed config files. These are automatically handled by the `prepare` method
based on the model, dataloaders, dummy optimizer and dummy scheduler provided to it.
Only the `auto` fields specified in the above examples are handled by the `prepare` method; the rest have to be explicitly specified by the user.
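
To make case (c) concrete, here is a minimal sketch (not taken from the example script): a custom PyTorch optimizer is kept while `accelerate.utils.DummyScheduler` stands in for the `scheduler` entry of the DeepSpeed config file. The `model`, `accelerator`, `max_train_steps` and `num_warmup_steps` names are placeholders:

```python
# Hypothetical sketch for case (c): custom optimizer + DeepSpeed scheduler.
# `model`, `accelerator`, `max_train_steps` and `num_warmup_steps` are assumed to exist already.
import torch
from accelerate.utils import DummyScheduler

# Any real PyTorch optimizer can be used, since `optimizer` is absent from the DeepSpeed config.
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3, momentum=0.9)

# The DummyScheduler is replaced by the scheduler defined under the `scheduler` key during `prepare`.
lr_scheduler = DummyScheduler(
    optimizer, total_num_steps=max_train_steps, warmup_num_steps=num_warmup_steps
)

model, optimizer, lr_scheduler = accelerator.prepare(model, optimizer, lr_scheduler)
```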

## Saving and loading

1. Saving and loading of models is unchanged for ZeRO Stage-1 and Stage-2.

2. Under ZeRO Stage-3, `state_dict` contains just the placeholders since the model weights are partitioned across multiple GPUs.
ZeRO Stage-3 has 2 options:

a. Saving the entire 16bit model weights to directly load later on using `model.load_state_dict(torch.load("pytorch_model.bin"))`.
For this, either set `zero_optimization.stage3_gather_16bit_weights_on_model_save` to True in the DeepSpeed Config file or set
`zero3_save_16bit_model` to True in the DeepSpeed Plugin.
**Note that this option requires consolidation of the weights on one GPU, which can be slow and memory demanding, so only use this feature when needed.**
Below is the snippet from `examples/by_feature/deepspeed_with_config_support.py` showing this:
```python
unwrapped_model = accelerator.unwrap_model(model)

# New Code #
# Saves the whole/unpartitioned fp16 model when in ZeRO Stage-3 to the output directory if
# `stage3_gather_16bit_weights_on_model_save` is True in DeepSpeed Config file or
# `zero3_save_16bit_model` is True in DeepSpeed Plugin.
# For ZeRO Stages 1 and 2, models are saved as usual in the output directory.
# The model name saved is `pytorch_model.bin`
unwrapped_model.save_pretrained(
    args.output_dir,
    is_main_process=accelerator.is_main_process,
    save_function=accelerator.save,
    state_dict=accelerator.get_state_dict(model),
)
```
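
The consolidated weights written above can later be reloaded like any regular PyTorch checkpoint. A minimal sketch, assuming the same model class and the `args.output_dir` used during training:

```python
# Hypothetical reload of the consolidated 16-bit weights saved by `save_pretrained` above.
import os
import torch

state_dict = torch.load(os.path.join(args.output_dir, "pytorch_model.bin"), map_location="cpu")
model.load_state_dict(state_dict)
```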

b. To get 32bit weights, first save the model using `model.save_checkpoint()`.
Below is the snippet from `examples/by_feature/deepspeed_with_config_support.py` showing this:
```python
success = model.save_checkpoint(PATH, ckpt_id, checkpoint_state_dict)
status_msg = "checkpointing: PATH={}, ckpt_id={}".format(PATH, ckpt_id)
if success:
    logging.info(f"Success {status_msg}")
else:
    logging.warning(f"Failure {status_msg}")
```
This will create ZeRO model and optimizer partitions along with the `zero_to_fp32.py` script in the checkpoint directory.
One can use this script to do offline consolidation.
It requires no configuration files or GPUs. Here is an example of its usage:
```bash
$ cd /path/to/checkpoint_dir
$ ./zero_to_fp32.py . pytorch_model.bin
Processing zero checkpoint at global_step1
Detected checkpoint of type zero stage 3, world_size: 2
Saving fp32 state dict to pytorch_model.bin (total_numel=60506624)
```
To get the 32-bit model for saving/inference, you can do the following:
```python
from deepspeed.utils.zero_to_fp32 import load_state_dict_from_zero_checkpoint

unwrapped_model = accelerator.unwrap_model(model)
fp32_model = load_state_dict_from_zero_checkpoint(unwrapped_model, checkpoint_dir)
```
If you are only interested in the `state_dict`, you can do the following:
```python
from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint

state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir)
```
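
The recovered `state_dict` can then be loaded into an unwrapped model of the same architecture for saving or inference. A short sketch, reusing the names from the snippets above:

```python
# Load the consolidated fp32 weights into the unwrapped model.
unwrapped_model = accelerator.unwrap_model(model)
unwrapped_model.load_state_dict(state_dict)
```
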
Note that all these functions require ~2x the size of the final checkpoint in general (CPU) RAM.

## ZeRO Inference
DeepSpeed ZeRO Inference supports ZeRO stage 3 with ZeRO-Infinity.
It uses the same ZeRO protocol as training, but it doesn't use an optimizer or an LR scheduler, and only stage 3 is relevant.
With the 🤗 Accelerate integration, you just need to prepare the model and dataloader as shown below:

```python
model, eval_dataloader = accelerator.prepare(model, eval_dataloader)
```
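
Putting it together, a minimal inference loop could look like the sketch below. This is not taken from the example script; `model`, `eval_dataloader` and the metric logic are placeholders, and the point is simply that no optimizer or scheduler is passed to `prepare`:

```python
# Hypothetical ZeRO-3 inference sketch: only the model and dataloader go through `prepare`.
import torch
from accelerate import Accelerator

accelerator = Accelerator()  # started via `accelerate launch` with a ZeRO Stage-3 config
model, eval_dataloader = accelerator.prepare(model, eval_dataloader)  # placeholders from your own code

model.eval()
for batch in eval_dataloader:
    with torch.no_grad():
        outputs = model(**batch)
    # Gather per-process outputs so every process sees the full set, if a metric needs them.
    logits = accelerator.gather(outputs.logits)
```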

## Few caveats to be aware of

1. The current integration doesn't support Pipeline Parallelism of DeepSpeed.
2. The current integration doesn't support `mpu`, limiting the tensor parallelism which is supported in Megatron-LM.
3. The current integration doesn't support multiple models for a given `accelerator` object.


## Internals

[[autodoc]] utils.DeepSpeedPlugin

[[autodoc]] utils.DummyOptim

[[autodoc]] utils.DummyScheduler

[[autodoc]] utils.DeepSpeedEngineWrapper

[[autodoc]] utils.DeepSpeedOptimizerWrapper

[[autodoc]] utils.DeepSpeedSchedulerWrapper


## Main DeepSpeed Resources

- [Project's github](https://github.com/microsoft/deepspeed)
- [Usage docs](https://www.deepspeed.ai/getting-started/)
- [API docs](https://deepspeed.readthedocs.io/en/latest/index.html)
- [Blog posts](https://www.microsoft.com/en-us/research/search/?q=deepspeed)

Papers:

- [ZeRO: Memory Optimizations Toward Training Trillion Parameter Models](https://arxiv.org/abs/1910.02054)
- [ZeRO-Offload: Democratizing Billion-Scale Model Training](https://arxiv.org/abs/2101.06840)
- [ZeRO-Infinity: Breaking the GPU Memory Wall for Extreme Scale Deep Learning](https://arxiv.org/abs/2104.07857)

Finally, please remember that 🤗 `Accelerate` only integrates DeepSpeed, so if you
have any problems or questions regarding DeepSpeed usage, please file an issue with [DeepSpeed GitHub](https://github.com/microsoft/DeepSpeed/issues).
examples/by_feature/deepspeed_with_config_support.py (new executable file, 736 lines)
@@ -0,0 +1,736 @@
|
||||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Fine-tuning the library models for causal language modeling (GPT, GPT-2, CTRL, ...)
|
||||
on a text file or a dataset without using HuggingFace Trainer.
|
||||
|
||||
Here is the full list of checkpoints on the hub that can be fine-tuned by this script:
|
||||
https://huggingface.co/models?filter=text-generation
|
||||
"""
|
||||
# You can also adapt this script on your own causal language modeling task. Pointers for this are left as comments.
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
import random
|
||||
from itertools import chain
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
import datasets
|
||||
import transformers
|
||||
from accelerate import Accelerator, DistributedType
|
||||
from accelerate.logging import get_logger
|
||||
from accelerate.utils import DummyOptim, DummyScheduler, set_seed
|
||||
from datasets import load_dataset
|
||||
from huggingface_hub import Repository
|
||||
from tqdm.auto import tqdm
|
||||
from transformers import (
|
||||
CONFIG_MAPPING,
|
||||
MODEL_MAPPING,
|
||||
AutoConfig,
|
||||
AutoModelForCausalLM,
|
||||
AutoTokenizer,
|
||||
SchedulerType,
|
||||
default_data_collator,
|
||||
get_scheduler,
|
||||
)
|
||||
from transformers.utils import get_full_repo_name
|
||||
from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")
|
||||
|
||||
MODEL_CONFIG_CLASSES = list(MODEL_MAPPING.keys())
|
||||
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(description="Finetune a transformers model on a causal language modeling task")
|
||||
parser.add_argument(
|
||||
"--dataset_name",
|
||||
type=str,
|
||||
default=None,
|
||||
help="The name of the dataset to use (via the datasets library).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dataset_config_name",
|
||||
type=str,
|
||||
default=None,
|
||||
help="The configuration name of the dataset to use (via the datasets library).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--train_file", type=str, default=None, help="A csv or a json file containing the training data."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--validation_file", type=str, default=None, help="A csv or a json file containing the validation data."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--validation_split_percentage",
|
||||
default=5,
|
||||
help="The percentage of the train set used as validation set in case there's no validation split",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model_name_or_path",
|
||||
type=str,
|
||||
help="Path to pretrained model or model identifier from huggingface.co/models.",
|
||||
required=False,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--config_name",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Pretrained config name or path if not the same as model_name",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tokenizer_name",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Pretrained tokenizer name or path if not the same as model_name",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_slow_tokenizer",
|
||||
action="store_true",
|
||||
help="If passed, will use a slow tokenizer (not backed by the 🤗 Tokenizers library).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--per_device_train_batch_size",
|
||||
type=int,
|
||||
default=8,
|
||||
help="Batch size (per device) for the training dataloader.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--per_device_eval_batch_size",
|
||||
type=int,
|
||||
default=8,
|
||||
help="Batch size (per device) for the evaluation dataloader.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--learning_rate",
|
||||
type=float,
|
||||
default=5e-5,
|
||||
help="Initial learning rate (after the potential warmup period) to use.",
|
||||
)
|
||||
parser.add_argument("--weight_decay", type=float, default=0.0, help="Weight decay to use.")
|
||||
parser.add_argument("--num_train_epochs", type=int, default=3, help="Total number of training epochs to perform.")
|
||||
parser.add_argument(
|
||||
"--max_train_steps",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--gradient_accumulation_steps",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of updates steps to accumulate before performing a backward/update pass.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--lr_scheduler_type",
|
||||
type=SchedulerType,
|
||||
default="linear",
|
||||
help="The scheduler type to use.",
|
||||
choices=["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"],
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num_warmup_steps", type=int, default=0, help="Number of steps for the warmup in the lr scheduler."
|
||||
)
|
||||
parser.add_argument("--output_dir", type=str, default=None, help="Where to store the final model.")
|
||||
parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
|
||||
parser.add_argument(
|
||||
"--model_type",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Model type to use if training from scratch.",
|
||||
choices=MODEL_TYPES,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--block_size",
|
||||
type=int,
|
||||
default=None,
|
||||
help=(
|
||||
"Optional input sequence length after tokenization. The training dataset will be truncated in block of"
|
||||
" this size for training. Default to the model max input length for single sentence inputs (take into"
|
||||
" account special tokens)."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--preprocessing_num_workers",
|
||||
type=int,
|
||||
default=None,
|
||||
help="The number of processes to use for the preprocessing.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--overwrite_cache", type=bool, default=False, help="Overwrite the cached training and evaluation sets"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--no_keep_linebreaks", action="store_true", help="Do not keep line breaks when using TXT files."
|
||||
)
|
||||
parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
|
||||
parser.add_argument(
|
||||
"--hub_model_id", type=str, help="The name of the repository to keep in sync with the local `output_dir`."
|
||||
)
|
||||
parser.add_argument("--hub_token", type=str, help="The token to use to push to the Model Hub.")
|
||||
parser.add_argument(
|
||||
"--checkpointing_steps",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Whether the various states should be saved at the end of every n steps, or 'epoch' for each epoch.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--resume_from_checkpoint",
|
||||
type=str,
|
||||
default=None,
|
||||
help="If the training should continue from a checkpoint folder.",
|
||||
)
|
||||
# New Code #
|
||||
# Whether to load the best model at the end of training
|
||||
parser.add_argument(
|
||||
"--load_best_model",
|
||||
action="store_true",
|
||||
help="Whether to load the best model at the end of training",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--with_tracking",
|
||||
action="store_true",
|
||||
help="Whether to enable experiment trackers for logging.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--report_to",
|
||||
type=str,
|
||||
default="all",
|
||||
help=(
|
||||
'The integration to report the results and logs to. Supported platforms are `"tensorboard"`,'
|
||||
' `"wandb"` and `"comet_ml"`. Use `"all"` (default) to report to all integrations.'
|
||||
"Only applicable when `--with_tracking` is passed."
|
||||
),
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
# Sanity checks
|
||||
if args.dataset_name is None and args.train_file is None and args.validation_file is None:
|
||||
raise ValueError("Need either a dataset name or a training/validation file.")
|
||||
else:
|
||||
if args.train_file is not None:
|
||||
extension = args.train_file.split(".")[-1]
|
||||
assert extension in ["csv", "json", "txt"], "`train_file` should be a csv, json or txt file."
|
||||
if args.validation_file is not None:
|
||||
extension = args.validation_file.split(".")[-1]
|
||||
assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, json or txt file."
|
||||
|
||||
if args.push_to_hub:
|
||||
assert args.output_dir is not None, "Need an `output_dir` to create a repo when `--push_to_hub` is passed."
|
||||
|
||||
return args
|
||||
|
||||
|
||||
# New Code #
|
||||
def checkpoint_model(checkpoint_folder, ckpt_id, model, epoch, last_global_step, **kwargs):
|
||||
"""Utility function for checkpointing model + optimizer dictionaries
|
||||
The main purpose for this is to be able to resume training from that instant again
|
||||
"""
|
||||
checkpoint_state_dict = {
|
||||
"epoch": epoch,
|
||||
"last_global_step": last_global_step,
|
||||
}
|
||||
# Add extra kwargs too
|
||||
checkpoint_state_dict.update(kwargs)
|
||||
|
||||
success = model.save_checkpoint(checkpoint_folder, ckpt_id, checkpoint_state_dict)
|
||||
status_msg = f"checkpointing: checkpoint_folder={checkpoint_folder}, ckpt_id={ckpt_id}"
|
||||
if success:
|
||||
logging.info(f"Success {status_msg}")
|
||||
else:
|
||||
logging.warning(f"Failure {status_msg}")
|
||||
return
|
||||
|
||||
|
||||
# New Code #
|
||||
def load_training_checkpoint(model, load_dir, tag=None, **kwargs):
|
||||
"""Utility function for checkpointing model + optimizer dictionaries
|
||||
The main purpose for this is to be able to resume training from that instant again
|
||||
"""
|
||||
_, checkpoint_state_dict = model.load_checkpoint(load_dir, tag=tag, **kwargs)
|
||||
epoch = checkpoint_state_dict["epoch"]
|
||||
last_global_step = checkpoint_state_dict["last_global_step"]
|
||||
del checkpoint_state_dict
|
||||
return (epoch, last_global_step)
|
||||
|
||||
|
||||
# New Code #
|
||||
def evaluate(args, model, eval_dataloader, accelerator, eval_dataset):
|
||||
model.eval()
|
||||
losses = []
|
||||
for step, batch in enumerate(eval_dataloader):
|
||||
with torch.no_grad():
|
||||
outputs = model(**batch)
|
||||
|
||||
loss = outputs.loss
|
||||
losses.append(accelerator.gather(loss.repeat(args.per_device_eval_batch_size)))
|
||||
|
||||
losses = torch.cat(losses)
|
||||
losses = losses[: len(eval_dataset)]
|
||||
try:
|
||||
eval_loss = torch.mean(losses)
|
||||
perplexity = math.exp(eval_loss)
|
||||
except OverflowError:
|
||||
perplexity = float("inf")
|
||||
return perplexity, eval_loss
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
|
||||
# Initialize the accelerator. We will let the accelerator handle device placement for us in this example.
|
||||
# If we're using tracking, we also need to initialize it here and it will by default pick up all supported trackers
|
||||
# in the environment
|
||||
accelerator = (
|
||||
Accelerator(log_with=args.report_to, logging_dir=args.output_dir) if args.with_tracking else Accelerator()
|
||||
)
|
||||
# Make one log on every process with the configuration for debugging.
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||
datefmt="%m/%d/%Y %H:%M:%S",
|
||||
level=logging.INFO,
|
||||
)
|
||||
logger.info(accelerator.state, main_process_only=False)
|
||||
if accelerator.is_local_main_process:
|
||||
datasets.utils.logging.set_verbosity_warning()
|
||||
transformers.utils.logging.set_verbosity_info()
|
||||
else:
|
||||
datasets.utils.logging.set_verbosity_error()
|
||||
transformers.utils.logging.set_verbosity_error()
|
||||
|
||||
# If passed along, set the training seed now.
|
||||
if args.seed is not None:
|
||||
set_seed(args.seed)
|
||||
|
||||
# Handle the repository creation
|
||||
if accelerator.is_main_process:
|
||||
if args.push_to_hub:
|
||||
if args.hub_model_id is None:
|
||||
repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
|
||||
else:
|
||||
repo_name = args.hub_model_id
|
||||
repo = Repository(args.output_dir, clone_from=repo_name)
|
||||
|
||||
with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
|
||||
if "step_*" not in gitignore:
|
||||
gitignore.write("step_*\n")
|
||||
if "epoch_*" not in gitignore:
|
||||
gitignore.write("epoch_*\n")
|
||||
elif args.output_dir is not None:
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
# Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
|
||||
# or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
|
||||
# (the dataset will be downloaded automatically from the datasets Hub).
|
||||
#
|
||||
# For CSV/JSON files, this script will use the column called 'text' or the first column if no column called
|
||||
# 'text' is found. You can easily tweak this behavior (see below).
|
||||
#
|
||||
# In distributed training, the load_dataset function guarantee that only one local process can concurrently
|
||||
# download the dataset.
|
||||
if args.dataset_name is not None:
|
||||
# Downloading and loading a dataset from the hub.
|
||||
raw_datasets = load_dataset(args.dataset_name, args.dataset_config_name)
|
||||
if "validation" not in raw_datasets.keys():
|
||||
raw_datasets["validation"] = load_dataset(
|
||||
args.dataset_name,
|
||||
args.dataset_config_name,
|
||||
split=f"train[:{args.validation_split_percentage}%]",
|
||||
)
|
||||
raw_datasets["train"] = load_dataset(
|
||||
args.dataset_name,
|
||||
args.dataset_config_name,
|
||||
split=f"train[{args.validation_split_percentage}%:]",
|
||||
)
|
||||
else:
|
||||
data_files = {}
|
||||
dataset_args = {}
|
||||
if args.train_file is not None:
|
||||
data_files["train"] = args.train_file
|
||||
if args.validation_file is not None:
|
||||
data_files["validation"] = args.validation_file
|
||||
extension = args.train_file.split(".")[-1]
|
||||
if extension == "txt":
|
||||
extension = "text"
|
||||
dataset_args["keep_linebreaks"] = not args.no_keep_linebreaks
|
||||
raw_datasets = load_dataset(extension, data_files=data_files, **dataset_args)
|
||||
# If no validation data is there, validation_split_percentage will be used to divide the dataset.
|
||||
if "validation" not in raw_datasets.keys():
|
||||
raw_datasets["validation"] = load_dataset(
|
||||
extension,
|
||||
data_files=data_files,
|
||||
split=f"train[:{args.validation_split_percentage}%]",
|
||||
**dataset_args,
|
||||
)
|
||||
raw_datasets["train"] = load_dataset(
|
||||
extension,
|
||||
data_files=data_files,
|
||||
split=f"train[{args.validation_split_percentage}%:]",
|
||||
**dataset_args,
|
||||
)
|
||||
|
||||
# See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
|
||||
# https://huggingface.co/docs/datasets/loading_datasets.html.
|
||||
|
||||
# Load pretrained model and tokenizer
|
||||
#
|
||||
# In distributed training, the .from_pretrained methods guarantee that only one local process can concurrently
|
||||
# download model & vocab.
|
||||
if args.config_name:
|
||||
config = AutoConfig.from_pretrained(args.config_name)
|
||||
elif args.model_name_or_path:
|
||||
config = AutoConfig.from_pretrained(args.model_name_or_path)
|
||||
else:
|
||||
config = CONFIG_MAPPING[args.model_type]()
|
||||
logger.warning("You are instantiating a new config instance from scratch.")
|
||||
|
||||
if args.tokenizer_name:
|
||||
tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, use_fast=not args.use_slow_tokenizer)
|
||||
elif args.model_name_or_path:
|
||||
tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=not args.use_slow_tokenizer)
|
||||
else:
|
||||
raise ValueError(
|
||||
"You are instantiating a new tokenizer from scratch. This is not supported by this script."
|
||||
"You can do it from another script, save it, and load it from here, using --tokenizer_name."
|
||||
)
|
||||
|
||||
if args.model_name_or_path:
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
args.model_name_or_path,
|
||||
from_tf=bool(".ckpt" in args.model_name_or_path),
|
||||
config=config,
|
||||
)
|
||||
else:
|
||||
logger.info("Training new model from scratch")
|
||||
model = AutoModelForCausalLM.from_config(config)
|
||||
|
||||
model.resize_token_embeddings(len(tokenizer))
|
||||
|
||||
# Preprocessing the datasets.
|
||||
# First we tokenize all the texts.
|
||||
column_names = raw_datasets["train"].column_names
|
||||
text_column_name = "text" if "text" in column_names else column_names[0]
|
||||
|
||||
def tokenize_function(examples):
|
||||
return tokenizer(examples[text_column_name])
|
||||
|
||||
with accelerator.main_process_first():
|
||||
tokenized_datasets = raw_datasets.map(
|
||||
tokenize_function,
|
||||
batched=True,
|
||||
num_proc=args.preprocessing_num_workers,
|
||||
remove_columns=column_names,
|
||||
load_from_cache_file=not args.overwrite_cache,
|
||||
desc="Running tokenizer on dataset",
|
||||
)
|
||||
|
||||
if args.block_size is None:
|
||||
block_size = tokenizer.model_max_length
|
||||
if block_size > 1024:
|
||||
logger.warning(
|
||||
f"The tokenizer picked seems to have a very large `model_max_length` ({tokenizer.model_max_length}). "
|
||||
"Picking 1024 instead. You can change that default value by passing --block_size xxx."
|
||||
)
|
||||
block_size = 1024
|
||||
else:
|
||||
if args.block_size > tokenizer.model_max_length:
|
||||
logger.warning(
|
||||
f"The block_size passed ({args.block_size}) is larger than the maximum length for the model"
|
||||
f"({tokenizer.model_max_length}). Using block_size={tokenizer.model_max_length}."
|
||||
)
|
||||
block_size = min(args.block_size, tokenizer.model_max_length)
|
||||
|
||||
# Main data processing function that will concatenate all texts from our dataset and generate chunks of block_size.
|
||||
def group_texts(examples):
|
||||
# Concatenate all texts.
|
||||
concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}
|
||||
total_length = len(concatenated_examples[list(examples.keys())[0]])
|
||||
# We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
|
||||
# customize this part to your needs.
|
||||
if total_length >= block_size:
|
||||
total_length = (total_length // block_size) * block_size
|
||||
# Split by chunks of max_len.
|
||||
result = {
|
||||
k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
|
||||
for k, t in concatenated_examples.items()
|
||||
}
|
||||
result["labels"] = result["input_ids"].copy()
|
||||
return result
|
||||
|
||||
# Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a remainder
|
||||
# for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value might be slower
|
||||
# to preprocess.
|
||||
#
|
||||
# To speed up this part, we use multiprocessing. See the documentation of the map method for more information:
|
||||
# https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map
|
||||
|
||||
with accelerator.main_process_first():
|
||||
lm_datasets = tokenized_datasets.map(
|
||||
group_texts,
|
||||
batched=True,
|
||||
num_proc=args.preprocessing_num_workers,
|
||||
load_from_cache_file=not args.overwrite_cache,
|
||||
desc=f"Grouping texts in chunks of {block_size}",
|
||||
)
|
||||
|
||||
train_dataset = lm_datasets["train"]
|
||||
eval_dataset = lm_datasets["validation"]
|
||||
|
||||
# Log a few random samples from the training set:
|
||||
for index in random.sample(range(len(train_dataset)), 3):
|
||||
logger.info(f"Sample {index} of the training set: {train_dataset[index]}.")
|
||||
|
||||
# DataLoaders creation:
|
||||
train_dataloader = DataLoader(
|
||||
train_dataset, shuffle=True, collate_fn=default_data_collator, batch_size=args.per_device_train_batch_size
|
||||
)
|
||||
eval_dataloader = DataLoader(
|
||||
eval_dataset, collate_fn=default_data_collator, batch_size=args.per_device_eval_batch_size
|
||||
)
|
||||
|
||||
# Optimizer
|
||||
# Split weights in two groups, one with weight decay and the other not.
|
||||
no_decay = ["bias", "LayerNorm.weight"]
|
||||
optimizer_grouped_parameters = [
|
||||
{
|
||||
"params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
|
||||
"weight_decay": args.weight_decay,
|
||||
},
|
||||
{
|
||||
"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
|
||||
"weight_decay": 0.0,
|
||||
},
|
||||
]
|
||||
# New Code #
|
||||
# Creates Dummy Optimizer if `optimizer` was spcified in the config file else creates Adam Optimizer
|
||||
optimizer_cls = (
|
||||
torch.optim.AdamW
|
||||
if accelerator.state.deepspeed_plugin is None
|
||||
or "optimizer" not in accelerator.state.deepspeed_plugin.deepspeed_config
|
||||
else DummyOptim
|
||||
)
|
||||
optimizer = optimizer_cls(optimizer_grouped_parameters, lr=args.learning_rate)
|
||||
|
||||
# On TPU, the tie weights in our model have been disconnected, so we need to restore the ties.
|
||||
if accelerator.distributed_type == DistributedType.TPU:
|
||||
model.tie_weights()
|
||||
|
||||
# Scheduler and math around the number of training steps.
|
||||
|
||||
# New Code
|
||||
# Get gradient accumulation steps from deepspeed config if available
|
||||
if accelerator.state.deepspeed_plugin is not None:
|
||||
args.gradient_accumulation_steps = accelerator.state.deepspeed_plugin.deepspeed_config[
|
||||
"gradient_accumulation_steps"
|
||||
]
|
||||
|
||||
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
||||
if args.max_train_steps is None:
|
||||
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
||||
else:
|
||||
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
|
||||
|
||||
# New Code #
|
||||
# Creates Dummy Scheduler if `scheduler` was spcified in the config file else creates `args.lr_scheduler_type` Scheduler
|
||||
if (
|
||||
accelerator.state.deepspeed_plugin is None
|
||||
or "scheduler" not in accelerator.state.deepspeed_plugin.deepspeed_config
|
||||
):
|
||||
lr_scheduler = get_scheduler(
|
||||
name=args.lr_scheduler_type,
|
||||
optimizer=optimizer,
|
||||
num_warmup_steps=args.num_warmup_steps,
|
||||
num_training_steps=args.max_train_steps,
|
||||
)
|
||||
else:
|
||||
lr_scheduler = DummyScheduler(
|
||||
optimizer, total_num_steps=args.max_train_steps, warmup_num_steps=args.num_warmup_steps
|
||||
)
|
||||
|
||||
# Prepare everything with our `accelerator`.
|
||||
model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare(
|
||||
model, optimizer, train_dataloader, eval_dataloader, lr_scheduler
|
||||
)
|
||||
|
||||
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
|
||||
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
||||
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
||||
|
||||
# Figure out how many steps we should save the Accelerator states
|
||||
if hasattr(args.checkpointing_steps, "isdigit"):
|
||||
checkpointing_steps = args.checkpointing_steps
|
||||
if args.checkpointing_steps.isdigit():
|
||||
checkpointing_steps = int(args.checkpointing_steps)
|
||||
else:
|
||||
checkpointing_steps = None
|
||||
|
||||
# We need to initialize the trackers we use, and also store our configuration.
|
||||
# We initialize the trackers only on main process because `accelerator.log`
|
||||
# only logs on main process and we don't want empty logs/runs on other processes.
|
||||
if args.with_tracking:
|
||||
if accelerator.is_main_process:
|
||||
experiment_config = vars(args)
|
||||
# TensorBoard cannot log Enums, need the raw value
|
||||
experiment_config["lr_scheduler_type"] = experiment_config["lr_scheduler_type"].value
|
||||
accelerator.init_trackers("clm_no_trainer", experiment_config)
|
||||
|
||||
# Train!
|
||||
total_batch_size = args.per_device_train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
|
||||
|
||||
logger.info("***** Running training *****")
|
||||
logger.info(f" Num examples = {len(train_dataset)}")
|
||||
logger.info(f" Num Epochs = {args.num_train_epochs}")
|
||||
logger.info(f" Instantaneous batch size per device = {args.per_device_train_batch_size}")
|
||||
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
|
||||
logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
|
||||
logger.info(f" Total optimization steps = {args.max_train_steps}")
|
||||
# Only show the progress bar once on each machine.
|
||||
progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
|
||||
completed_steps = 0
|
||||
starting_epoch = 0
|
||||
best_metric = None
|
||||
best_metric_checkpoint = None
|
||||
|
||||
# Potentially load in the weights and states from a previous save
|
||||
if args.resume_from_checkpoint:
|
||||
# New Code #
|
||||
# Loads the DeepSpeed checkpoint from the specified path
|
||||
_, last_global_step = load_training_checkpoint(
|
||||
model,
|
||||
args.resume_from_checkpoint,
|
||||
**{"load_optimizer_states": True, "load_lr_scheduler_states": True},
|
||||
)
|
||||
accelerator.print(f"Resumed from checkpoint: {args.resume_from_checkpoint}")
|
||||
resume_step = last_global_step
|
||||
starting_epoch = resume_step // len(train_dataloader)
|
||||
resume_step -= starting_epoch * len(train_dataloader)
|
||||
|
||||
for epoch in range(starting_epoch, args.num_train_epochs):
|
||||
model.train()
|
||||
if args.with_tracking:
|
||||
total_loss = 0
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
# We need to skip steps until we reach the resumed step
|
||||
if args.resume_from_checkpoint and epoch == starting_epoch:
|
||||
if resume_step is not None and step < resume_step:
|
||||
completed_steps += 1
|
||||
continue
|
||||
outputs = model(**batch)
|
||||
loss = outputs.loss
|
||||
# We keep track of the loss at each epoch
|
||||
if args.with_tracking:
|
||||
total_loss += loss.detach().float()
|
||||
loss = loss / args.gradient_accumulation_steps
|
||||
accelerator.backward(loss)
|
||||
if step % args.gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
|
||||
optimizer.step()
|
||||
lr_scheduler.step()
|
||||
optimizer.zero_grad()
|
||||
progress_bar.update(1)
|
||||
completed_steps += 1
|
||||
|
||||
if isinstance(checkpointing_steps, int):
|
||||
if completed_steps % checkpointing_steps == 0:
|
||||
output_dir = f"step_{completed_steps}"
|
||||
if args.output_dir is not None:
|
||||
output_dir = os.path.join(args.output_dir, output_dir)
|
||||
accelerator.save_state(output_dir)
|
||||
if completed_steps >= args.max_train_steps:
|
||||
break
|
||||
|
||||
perplexity, eval_loss = evaluate(args, model, eval_dataloader, accelerator, eval_dataset)
|
||||
logger.info(f"epoch {epoch}: perplexity: {perplexity} eval_loss: {eval_loss}")
|
||||
|
||||
if args.with_tracking:
|
||||
accelerator.log(
|
||||
{
|
||||
"perplexity": perplexity,
|
||||
"eval_loss": eval_loss,
|
||||
"train_loss": total_loss.item() / len(train_dataloader),
|
||||
"epoch": epoch,
|
||||
"step": completed_steps,
|
||||
},
|
||||
step=completed_steps,
|
||||
)
|
||||
|
||||
# New Code #
|
||||
# Save the DeepSpeed checkpoint to the specified path
|
||||
checkpoint_model(args.output_dir, epoch, model, epoch, completed_steps)
|
||||
|
||||
# New Code #
|
||||
# Tracks the best checkpoint and best metric
|
||||
if best_metric is None or best_metric > perplexity:
|
||||
best_metric = perplexity
|
||||
best_metric_checkpoint = os.path.join(args.output_dir, str(epoch))
|
||||
accelerator.print(f"New best metric: {best_metric} at epoch {epoch}")
|
||||
accelerator.print(f"best_metric_checkpoint: {best_metric_checkpoint}")
|
||||
|
||||
# New Code #
|
||||
# Loads the best checkpoint after the training is finished
|
||||
if args.load_best_model:
|
||||
_, last_global_step = load_training_checkpoint(
|
||||
model,
|
||||
"/".join(best_metric_checkpoint.split("/")[:-1]),
|
||||
tag=best_metric_checkpoint.split("/")[-1],
|
||||
**{"load_optimizer_states": True, "load_lr_scheduler_states": True},
|
||||
)
|
||||
|
||||
# New Code #
|
||||
# Evaluates using the best checkpoint
|
||||
perplexity, eval_loss = evaluate(args, model, eval_dataloader, accelerator, eval_dataset)
|
||||
logger.info(f"Best model metrics: perplexity: {perplexity} eval_loss: {eval_loss}")
|
||||
if perplexity != best_metric:
|
||||
raise AssertionError(
|
||||
f"Best metric {best_metric} does not match the metric {perplexity} of the loaded best model."
|
||||
)
|
||||
|
||||
if args.output_dir is not None:
|
||||
accelerator.wait_for_everyone()
|
||||
unwrapped_model = accelerator.unwrap_model(model)
|
||||
|
||||
# New Code #
|
||||
# Saves the whole/unpartitioned fp16 model when in ZeRO Stage-3 to the output directory if
|
||||
# `stage3_gather_16bit_weights_on_model_save` is True in DeepSpeed Config file or
|
||||
# `zero3_save_16bit_model` is True in DeepSpeed Plugin.
|
||||
# For Zero Stages 1 and 2, models are saved as usual in the output directory.
|
||||
# The model name saved is `pytorch_model.bin`
|
||||
unwrapped_model.save_pretrained(
|
||||
args.output_dir,
|
||||
is_main_process=accelerator.is_main_process,
|
||||
save_function=accelerator.save,
|
||||
state_dict=accelerator.get_state_dict(model),
|
||||
)
|
||||
if accelerator.is_main_process:
|
||||
tokenizer.save_pretrained(args.output_dir)
|
||||
if args.push_to_hub:
|
||||
repo.push_to_hub(commit_message="End of training", auto_lfs_prune=True)
|
||||
|
||||
with open(os.path.join(args.output_dir, "all_results.json"), "w") as f:
|
||||
json.dump({"perplexity": perplexity, "eval_loss": eval_loss.item()}, f)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
examples/deepspeed_config_templates/zero_stage1_config.json (new file, 43 lines)
@@ -0,0 +1,43 @@
|
||||
{
|
||||
"fp16": {
|
||||
"enabled": true,
|
||||
"loss_scale": 0,
|
||||
"loss_scale_window": 1000,
|
||||
"initial_scale_power": 16,
|
||||
"hysteresis": 2,
|
||||
"min_loss_scale": 1
|
||||
},
|
||||
"optimizer": {
|
||||
"type": "AdamW",
|
||||
"params": {
|
||||
"lr": "auto",
|
||||
"weight_decay": "auto",
|
||||
"torch_adam": true,
|
||||
"adam_w_mode": true
|
||||
}
|
||||
},
|
||||
"scheduler": {
|
||||
"type": "WarmupDecayLR",
|
||||
"params": {
|
||||
"warmup_min_lr": "auto",
|
||||
"warmup_max_lr": "auto",
|
||||
"warmup_num_steps": "auto",
|
||||
"total_num_steps": "auto"
|
||||
}
|
||||
},
|
||||
"zero_optimization": {
|
||||
"stage": 1,
|
||||
"allgather_partitions": true,
|
||||
"allgather_bucket_size": 2e8,
|
||||
"overlap_comm": true,
|
||||
"reduce_scatter": true,
|
||||
"reduce_bucket_size": "auto",
|
||||
"contiguous_gradients": true
|
||||
},
|
||||
"gradient_accumulation_steps": 1,
|
||||
"gradient_clipping": "auto",
|
||||
"steps_per_print": 2000,
|
||||
"train_batch_size": "auto",
|
||||
"train_micro_batch_size_per_gpu": "auto",
|
||||
"wall_clock_breakdown": false
|
||||
}
|
43
examples/deepspeed_config_templates/zero_stage2_config.json
Normal file
@ -0,0 +1,43 @@
|
||||
{
|
||||
"fp16": {
|
||||
"enabled": true,
|
||||
"loss_scale": 0,
|
||||
"loss_scale_window": 1000,
|
||||
"initial_scale_power": 16,
|
||||
"hysteresis": 2,
|
||||
"min_loss_scale": 1
|
||||
},
|
||||
"optimizer": {
|
||||
"type": "AdamW",
|
||||
"params": {
|
||||
"lr": "auto",
|
||||
"weight_decay": "auto",
|
||||
"torch_adam": true,
|
||||
"adam_w_mode": true
|
||||
}
|
||||
},
|
||||
"scheduler": {
|
||||
"type": "WarmupDecayLR",
|
||||
"params": {
|
||||
"warmup_min_lr": "auto",
|
||||
"warmup_max_lr": "auto",
|
||||
"warmup_num_steps": "auto",
|
||||
"total_num_steps": "auto"
|
||||
}
|
||||
},
|
||||
"zero_optimization": {
|
||||
"stage": 2,
|
||||
"allgather_partitions": true,
|
||||
"allgather_bucket_size": 2e8,
|
||||
"overlap_comm": true,
|
||||
"reduce_scatter": true,
|
||||
"reduce_bucket_size": "auto",
|
||||
"contiguous_gradients": true
|
||||
},
|
||||
"gradient_accumulation_steps": 1,
|
||||
"gradient_clipping": "auto",
|
||||
"steps_per_print": 2000,
|
||||
"train_batch_size": "auto",
|
||||
"train_micro_batch_size_per_gpu": "auto",
|
||||
"wall_clock_breakdown": false
|
||||
}
|
@ -0,0 +1,47 @@
|
||||
{
|
||||
"fp16": {
|
||||
"enabled": true,
|
||||
"loss_scale": 0,
|
||||
"loss_scale_window": 1000,
|
||||
"initial_scale_power": 16,
|
||||
"hysteresis": 2,
|
||||
"min_loss_scale": 1
|
||||
},
|
||||
"optimizer": {
|
||||
"type": "AdamW",
|
||||
"params": {
|
||||
"lr": "auto",
|
||||
"weight_decay": "auto",
|
||||
"torch_adam": true,
|
||||
"adam_w_mode": true
|
||||
}
|
||||
},
|
||||
"scheduler": {
|
||||
"type": "WarmupDecayLR",
|
||||
"params": {
|
||||
"warmup_min_lr": "auto",
|
||||
"warmup_max_lr": "auto",
|
||||
"warmup_num_steps": "auto",
|
||||
"total_num_steps": "auto"
|
||||
}
|
||||
},
|
||||
"zero_optimization": {
|
||||
"stage": 2,
|
||||
"offload_optimizer": {
|
||||
"device": "cpu",
|
||||
"pin_memory": true
|
||||
},
|
||||
"allgather_partitions": true,
|
||||
"allgather_bucket_size": 2e8,
|
||||
"overlap_comm": true,
|
||||
"reduce_scatter": true,
|
||||
"reduce_bucket_size": "auto",
|
||||
"contiguous_gradients": true
|
||||
},
|
||||
"gradient_accumulation_steps": 1,
|
||||
"gradient_clipping": "auto",
|
||||
"steps_per_print": 2000,
|
||||
"train_batch_size": "auto",
|
||||
"train_micro_batch_size_per_gpu": "auto",
|
||||
"wall_clock_breakdown": false
|
||||
}
|
44
examples/deepspeed_config_templates/zero_stage3_config.json
Normal file
@ -0,0 +1,44 @@
|
||||
{
|
||||
"fp16": {
|
||||
"enabled": true,
|
||||
"loss_scale": 0,
|
||||
"loss_scale_window": 1000,
|
||||
"initial_scale_power": 16,
|
||||
"hysteresis": 2,
|
||||
"min_loss_scale": 1
|
||||
},
|
||||
"optimizer": {
|
||||
"type": "AdamW",
|
||||
"params": {
|
||||
"lr": "auto",
|
||||
"weight_decay": "auto"
|
||||
}
|
||||
},
|
||||
"scheduler": {
|
||||
"type": "WarmupDecayLR",
|
||||
"params": {
|
||||
"warmup_min_lr": "auto",
|
||||
"warmup_max_lr": "auto",
|
||||
"warmup_num_steps": "auto",
|
||||
"total_num_steps": "auto"
|
||||
}
|
||||
},
|
||||
"zero_optimization": {
|
||||
"stage": 3,
|
||||
"overlap_comm": true,
|
||||
"contiguous_gradients": true,
|
||||
"reduce_bucket_size": "auto",
|
||||
"stage3_prefetch_bucket_size": "auto",
|
||||
"stage3_param_persistence_threshold": "auto",
|
||||
"sub_group_size": 1e9,
|
||||
"stage3_max_live_parameters": 1e9,
|
||||
"stage3_max_reuse_distance": 1e9,
|
||||
"stage3_gather_16bit_weights_on_model_save": "auto"
|
||||
},
|
||||
"gradient_accumulation_steps": 1,
|
||||
"gradient_clipping": "auto",
|
||||
"steps_per_print": 2000,
|
||||
"train_batch_size": "auto",
|
||||
"train_micro_batch_size_per_gpu": "auto",
|
||||
"wall_clock_breakdown": false
|
||||
}
|
@ -0,0 +1,52 @@
|
||||
{
|
||||
"fp16": {
|
||||
"enabled": true,
|
||||
"loss_scale": 0,
|
||||
"loss_scale_window": 1000,
|
||||
"initial_scale_power": 16,
|
||||
"hysteresis": 2,
|
||||
"min_loss_scale": 1
|
||||
},
|
||||
"optimizer": {
|
||||
"type": "AdamW",
|
||||
"params": {
|
||||
"lr": "auto",
|
||||
"weight_decay": "auto"
|
||||
}
|
||||
},
|
||||
"scheduler": {
|
||||
"type": "WarmupDecayLR",
|
||||
"params": {
|
||||
"warmup_min_lr": "auto",
|
||||
"warmup_max_lr": "auto",
|
||||
"warmup_num_steps": "auto",
|
||||
"total_num_steps": "auto"
|
||||
}
|
||||
},
|
||||
"zero_optimization": {
|
||||
"stage": 3,
|
||||
"offload_optimizer": {
|
||||
"device": "cpu",
|
||||
"pin_memory": true
|
||||
},
|
||||
"offload_param": {
|
||||
"device": "cpu",
|
||||
"pin_memory": true
|
||||
},
|
||||
"overlap_comm": true,
|
||||
"contiguous_gradients": true,
|
||||
"reduce_bucket_size": "auto",
|
||||
"stage3_prefetch_bucket_size": "auto",
|
||||
"stage3_param_persistence_threshold": "auto",
|
||||
"sub_group_size": 1e9,
|
||||
"stage3_max_live_parameters": 1e9,
|
||||
"stage3_max_reuse_distance": 1e9,
|
||||
"stage3_gather_16bit_weights_on_model_save": "auto"
|
||||
},
|
||||
"gradient_accumulation_steps": 1,
|
||||
"gradient_clipping": "auto",
|
||||
"steps_per_print": 2000,
|
||||
"train_batch_size": "auto",
|
||||
"train_micro_batch_size_per_gpu": "auto",
|
||||
"wall_clock_breakdown": false
|
||||
}
|
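The templates above leave batch sizes, learning-rate and clipping values as `"auto"`; those entries are filled in later by `Accelerator.prepare` from the objects passed to it. A minimal sketch of pointing the revamped `DeepSpeedPlugin` at such a template (import paths follow the tests added in this PR; the template path is illustrative and the script is assumed to be started with `accelerate launch`):

```python
from accelerate.accelerator import Accelerator
from accelerate.utils.dataclasses import DeepSpeedPlugin

# Point the plugin at one of the JSON templates above; "auto" entries are
# resolved later, when the model/optimizer/dataloaders go through `prepare`.
deepspeed_plugin = DeepSpeedPlugin(
    config_file="examples/deepspeed_config_templates/zero_stage3_config.json",
    zero3_init_flag=True,  # only honoured for ZeRO Stage-3 configs
)
accelerator = Accelerator(deepspeed_plugin=deepspeed_plugin)
```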
4
setup.py
@ -27,7 +27,9 @@ extras["test"] = [
|
||||
"evaluate",
|
||||
"transformers",
|
||||
"scipy",
|
||||
"sklearn"
|
||||
"sklearn",
|
||||
"parameterized",
|
||||
"deepspeed",
|
||||
]
|
||||
|
||||
extras["test_trackers"] = ["wandb", "comet-ml", "tensorboard"]
|
||||
|
@ -13,10 +13,12 @@
|
||||
# limitations under the License.
|
||||
|
||||
import gc
|
||||
import math
|
||||
import os
|
||||
import sys
|
||||
import warnings
|
||||
from contextlib import contextmanager
|
||||
from copy import deepcopy
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import torch
|
||||
@ -39,12 +41,14 @@ from .utils import (
|
||||
LoggerType,
|
||||
PrecisionType,
|
||||
RNGType,
|
||||
compare_versions,
|
||||
convert_outputs_to_fp32,
|
||||
extract_model_from_parallel,
|
||||
gather,
|
||||
get_pretty_name,
|
||||
is_deepspeed_available,
|
||||
is_torch_version,
|
||||
is_transformers_available,
|
||||
pad_across_processes,
|
||||
reduce,
|
||||
save,
|
||||
@ -55,7 +59,13 @@ from .utils import (
|
||||
if is_deepspeed_available():
|
||||
import deepspeed
|
||||
|
||||
from .utils import DeepSpeedEngineWrapper, DeepSpeedOptimizerWrapper
|
||||
from .utils import (
|
||||
DeepSpeedEngineWrapper,
|
||||
DeepSpeedOptimizerWrapper,
|
||||
DeepSpeedSchedulerWrapper,
|
||||
DummyOptim,
|
||||
DummyScheduler,
|
||||
)
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@ -163,6 +173,30 @@ class Accelerator:
|
||||
deepspeed_plugin, DeepSpeedPlugin
|
||||
), "`deepspeed_plugin` must be a DeepSpeedPlugin object."
|
||||
os.environ["USE_DEEPSPEED"] = "true" # use DeepSpeed if plugin is provided
|
||||
if deepspeed_plugin:
|
||||
if not is_deepspeed_available():
|
||||
raise ImportError("DeepSpeed is not installed => run `pip install deepspeed` or build it from source.")
|
||||
if compare_versions("deepspeed", "<", "0.6.5"):
|
||||
raise ImportError("DeepSpeed version must be >= 0.6.5. Please update DeepSpeed.")
|
||||
if os.environ.get("DEEPSPEED_ZERO3_INIT", "false") == "true" or deepspeed_plugin.zero3_init_flag:
|
||||
if not is_transformers_available():
|
||||
raise Exception(
|
||||
"When `zero3_init_flag` is set, it requires Transformers to be installed. "
|
||||
"Please run `pip install transformers`."
|
||||
)
|
||||
from transformers.deepspeed import HfDeepSpeedConfig
|
||||
|
||||
ds_config = deepcopy(deepspeed_plugin.deepspeed_config)
|
||||
del ds_config["train_batch_size"]
|
||||
ds_config.update({"train_micro_batch_size_per_gpu": 1, "gradient_accumulation_steps": 1})
|
||||
mixed_precision = (
|
||||
os.environ.get("MIXED_PRECISION", "no") if mixed_precision is None else mixed_precision
|
||||
)
|
||||
if mixed_precision == "fp16":
|
||||
ds_config.update({"fp16": {"enabled": True}})
|
||||
elif mixed_precision == "bf16":
|
||||
ds_config.update({"bf16": {"enabled": True}})
|
||||
self.dschf = HfDeepSpeedConfig(ds_config) # keep this object alive # noqa
|
||||
|
||||
if fsdp_plugin is None: # init from env variables
|
||||
fsdp_plugin = FullyShardedDataParallelPlugin() if os.environ.get("USE_FSDP", "false") == "true" else None
|
||||
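In short, the `zero3_init_flag` branch above keeps an `HfDeepSpeedConfig` object alive so that Transformers can construct massive models directly in partitioned form. A rough sketch of the user-visible effect (the model name is only an example; assumes the script was started with `accelerate launch` using a ZeRO Stage-3 config with `zero3_init_flag` enabled and that `transformers` is installed):

```python
from accelerate import Accelerator
from transformers import AutoModelForCausalLM

accelerator = Accelerator()  # picks up the DeepSpeed ZeRO-3 settings from the environment
# With zero.Init active, the weights are sharded across processes while they are
# being loaded, instead of first materialising the full model on every rank.
model = AutoModelForCausalLM.from_pretrained("gpt2")
model = accelerator.prepare(model)
```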
@ -497,9 +531,15 @@ class Accelerator:
|
||||
def _prepare_deepspeed(self, *args):
|
||||
|
||||
deepspeed_plugin = self.state.deepspeed_plugin
|
||||
self.deepspeed_config = deepspeed_plugin.deepspeed_config
|
||||
|
||||
result = [
|
||||
self._prepare_one(obj, first_pass=True) if isinstance(obj, torch.utils.data.DataLoader) else obj
|
||||
for obj in args
|
||||
]
|
||||
|
||||
batch_sizes = [obj.batch_size for obj in args if hasattr(obj, "batch_size")]
|
||||
if self.split_batches:
|
||||
batch_sizes = [batch_size // self.num_processes for batch_size in batch_sizes]
|
||||
if len(batch_sizes) == 0:
|
||||
raise ValueError(
|
||||
"You must specify a training or evaluation dataloader in `accelerate.prepare()` when using DeepSpeed."
|
||||
@ -508,73 +548,141 @@ class Accelerator:
|
||||
batch_size_per_device = min(batch_sizes) if deepspeed_plugin.is_train_batch_min else max(batch_sizes)
|
||||
if len(batch_sizes) > 1:
|
||||
logger.info(
|
||||
f"Since you passed both train and evaluation dataloader, `is_train_batch_min` (here \
|
||||
{deepspeed_plugin.is_train_batch_min} will decide the `train_batch_size` ({batch_size_per_device})."
|
||||
"Since you passed both train and evaluation dataloader, `is_train_batch_min` (here "
|
||||
f"{deepspeed_plugin.is_train_batch_min} will decide the `train_batch_size` ({batch_size_per_device})."
|
||||
)
|
||||
|
||||
self.deepspeed_config["train_batch_size"] = (
|
||||
batch_size_per_device * deepspeed_plugin.gradient_accumulation_steps * self.num_processes
|
||||
)
|
||||
|
||||
result = [
|
||||
self._prepare_one(obj, first_pass=True) if isinstance(obj, torch.utils.data.DataLoader) else obj
|
||||
for obj in args
|
||||
]
|
||||
config_kwargs = {
|
||||
"train_micro_batch_size_per_gpu": batch_size_per_device,
|
||||
"train_batch_size": batch_size_per_device
|
||||
* deepspeed_plugin.deepspeed_config["gradient_accumulation_steps"]
|
||||
* self.num_processes,
|
||||
"gradient_clipping": 1.0,
|
||||
"zero_optimization.stage3_gather_16bit_weights_on_model_save": False,
|
||||
}
|
||||
|
||||
model = None
|
||||
optimizer = None
|
||||
scheduler = None
|
||||
for obj in result:
|
||||
if isinstance(obj, torch.nn.Module):
|
||||
model = obj
|
||||
elif isinstance(obj, (torch.optim.Optimizer, dict)):
|
||||
elif isinstance(obj, (torch.optim.Optimizer, DummyOptim)):
|
||||
optimizer = obj
|
||||
elif (isinstance(obj, (torch.optim.lr_scheduler._LRScheduler, DummyScheduler))) or (
|
||||
type(obj).__name__ in deepspeed.runtime.lr_schedules.VALID_LR_SCHEDULES
|
||||
):
|
||||
scheduler = obj
|
||||
|
||||
if deepspeed_plugin.auto_opt_mapping:
|
||||
is_adam = isinstance(optimizer, torch.optim.Adam)
|
||||
is_adamw = isinstance(optimizer, torch.optim.AdamW)
|
||||
if (is_adam or is_adamw) and deepspeed_plugin.offload_optimizer_device == "cpu":
|
||||
defaults = optimizer.defaults
|
||||
params = []
|
||||
for group in optimizer.param_groups:
|
||||
params.extend(group["params"])
|
||||
|
||||
optimizer = deepspeed.ops.adam.DeepSpeedCPUAdam(
|
||||
params,
|
||||
lr=defaults["lr"],
|
||||
bias_correction=True,
|
||||
betas=defaults["betas"],
|
||||
eps=defaults["eps"],
|
||||
weight_decay=defaults["weight_decay"],
|
||||
amsgrad=defaults["amsgrad"],
|
||||
adamw_mode=is_adamw,
|
||||
if optimizer is not None:
|
||||
if "optimizer" in deepspeed_plugin.deepspeed_config and not isinstance(optimizer, (DummyOptim)):
|
||||
raise ValueError(
|
||||
"You cannot specify an optimizer in the config file and in the code at the same time. "
|
||||
"Please remove the optimizer from the config file or "
|
||||
"create `accelerate.utils.DummyOptim` in the code."
|
||||
)
|
||||
elif "optimizer" not in deepspeed_plugin.deepspeed_config and isinstance(optimizer, (DummyOptim)):
|
||||
raise ValueError(
|
||||
"You cannot create a `DummyOptim` without specifying an optimizer in the config file."
|
||||
)
|
||||
|
||||
if isinstance(optimizer, (torch.optim.Optimizer)):
|
||||
deepspeed_plugin.deepspeed_config["zero_allow_untested_optimizer"] = True
|
||||
|
||||
if scheduler is not None:
|
||||
if "scheduler" in deepspeed_plugin.deepspeed_config and not isinstance(scheduler, (DummyScheduler)):
|
||||
raise ValueError(
|
||||
"You cannot specify a scheduler in the config file and in the code at the same time. "
|
||||
"Please remove the scheduler from the config file or "
|
||||
"create `accelerate.utils.DummyScheduler` in the code."
|
||||
)
|
||||
elif "scheduler" not in deepspeed_plugin.deepspeed_config and isinstance(scheduler, (DummyScheduler)):
|
||||
raise ValueError(
|
||||
"You cannot create a `DummyScheduler` without specifying a scheduler in the config file."
|
||||
)
|
||||
|
||||
if optimizer is not None and scheduler is not None:
|
||||
if isinstance(optimizer, (DummyOptim)) and not isinstance(scheduler, (DummyScheduler)):
|
||||
raise ValueError(
|
||||
"You can only specify `accelerate.utils.DummyScheduler` in the code when using "
|
||||
"`accelerate.utils.DummyOptim`."
|
||||
)
|
||||
|
||||
# useful when only eval_dataloader is given into `accelerator.prepare()`
|
||||
if model is not None:
|
||||
engine = DeepSpeedEngineWrapper(
|
||||
args=None,
|
||||
model=model,
|
||||
optimizer=optimizer,
|
||||
config_params=self.deepspeed_config,
|
||||
dist_init_required=False,
|
||||
)
|
||||
if hasattr(model, "config") and hasattr(model.config, "hidden_size"):
|
||||
hidden_size = model.config.hidden_size
|
||||
config_kwargs.update(
|
||||
{
|
||||
"zero_optimization.reduce_bucket_size": hidden_size * hidden_size,
|
||||
"zero_optimization.stage3_prefetch_bucket_size": 0.9 * hidden_size * hidden_size,
|
||||
"zero_optimization.stage3_param_persistence_threshold": 10 * hidden_size,
|
||||
}
|
||||
)
|
||||
|
||||
if isinstance(optimizer, (DummyOptim)):
|
||||
config_kwargs.update(
|
||||
{"optimizer.params.lr": optimizer.lr, "optimizer.params.weight_decay": optimizer.weight_decay}
|
||||
)
|
||||
if isinstance(scheduler, (DummyScheduler)):
|
||||
config_kwargs.update(
|
||||
{
|
||||
"scheduler.params.warmup_min_lr": 0,
|
||||
"scheduler.params.warmup_max_lr": scheduler.optimizer.lr,
|
||||
"scheduler.params.warmup_num_steps": scheduler.warmup_num_steps,
|
||||
}
|
||||
)
|
||||
if scheduler.total_num_steps is not None:
|
||||
config_kwargs["scheduler.params.total_num_steps"] = (
|
||||
math.ceil(scheduler.total_num_steps / self.num_processes)
|
||||
if not self.split_batches
|
||||
else scheduler.total_num_steps
|
||||
)
|
||||
deepspeed_plugin.deepspeed_config_process(must_match=False, **config_kwargs)
|
||||
self.deepspeed_config = deepspeed_plugin.deepspeed_config
|
||||
kwargs = dict(model=model, config_params=self.deepspeed_config)
|
||||
if optimizer is not None:
|
||||
if isinstance(optimizer, (DummyOptim)):
|
||||
kwargs["model_parameters"] = optimizer.params
|
||||
else:
|
||||
kwargs["optimizer"] = optimizer
|
||||
if scheduler is not None:
|
||||
if type(scheduler).__name__ in deepspeed.runtime.lr_schedules.VALID_LR_SCHEDULES:
|
||||
kwargs["lr_scheduler"] = scheduler
|
||||
|
||||
engine, optimizer, _, lr_scheduler = deepspeed.initialize(**kwargs)
|
||||
if optimizer is not None:
|
||||
optimizer = DeepSpeedOptimizerWrapper(optimizer)
|
||||
if scheduler is not None:
|
||||
if lr_scheduler is None:
|
||||
scheduler = AcceleratedScheduler(
|
||||
scheduler,
|
||||
optimizer,
|
||||
step_with_optimizer=self.step_scheduler_with_optimizer,
|
||||
split_batches=self.split_batches,
|
||||
)
|
||||
else:
|
||||
scheduler = DeepSpeedSchedulerWrapper(lr_scheduler, optimizer)
|
||||
|
||||
for i in range(len(result)):
|
||||
if isinstance(result[i], torch.nn.Module):
|
||||
result[i] = engine
|
||||
elif isinstance(result[i], torch.optim.Optimizer):
|
||||
result[i] = DeepSpeedOptimizerWrapper(engine.optimizer, engine)
|
||||
self.deepspeed_engine = engine # pointing for deepspeed_engine.backward()
|
||||
elif isinstance(result[i], (torch.optim.Optimizer, DummyOptim)):
|
||||
result[i] = optimizer
|
||||
elif (isinstance(result[i], (torch.optim.lr_scheduler._LRScheduler, DummyScheduler))) or (
|
||||
type(result[i]).__name__ in deepspeed.runtime.lr_schedules.VALID_LR_SCHEDULES
|
||||
):
|
||||
result[i] = scheduler
|
||||
# pointing for deepspeed_engine_wrapped.backward()
|
||||
self.deepspeed_engine_wrapped = DeepSpeedEngineWrapper(engine)
|
||||
self._models.append(engine)
|
||||
self._optimizers.append(engine.optimizer)
|
||||
assert (
|
||||
len(self._models) == 1
|
||||
), "You can't use same `Accelerator()` instance with 2 models when using DeepSpeed"
|
||||
|
||||
if self.distributed_type == DistributedType.DEEPSPEED:
|
||||
assert hasattr(
|
||||
self, "deepspeed_engine"
|
||||
), "You need to pass the model along the optimizer when using Deepspeed."
|
||||
|
||||
if optimizer is not None:
|
||||
self._optimizers.append(optimizer)
|
||||
if scheduler is not None:
|
||||
self._schedulers.append(scheduler)
|
||||
if len(self._models) > 1:
|
||||
raise AssertionError(
|
||||
"You can't use same `Accelerator()` instance with multiple models when using DeepSpeed"
|
||||
)
|
||||
return tuple(result)
|
||||
|
||||
def prepare_data_loader(self, data_loader):
|
||||
@ -612,7 +720,7 @@ class Accelerator:
|
||||
Use `accelerator.backward(loss)` in lieu of `loss.backward()`.
|
||||
"""
|
||||
if self.distributed_type == DistributedType.DEEPSPEED:
|
||||
self.deepspeed_engine.backward(loss, **kwargs)
|
||||
self.deepspeed_engine_wrapped.backward(loss, **kwargs)
|
||||
elif self.scaler is not None:
|
||||
self.scaler.scale(loss).backward(**kwargs)
|
||||
else:
|
||||
@ -648,6 +756,9 @@ class Accelerator:
|
||||
if parameters == [p for p in model.parameters()]:
|
||||
model.clip_grad_norm_(max_norm, norm_type)
|
||||
return
|
||||
elif self.distributed_type == DistributedType.DEEPSPEED:
|
||||
# `accelerator.backward(loss)` is doing that automatically. Therefore, its implementation is not needed
|
||||
return
|
||||
self.unscale_gradients()
|
||||
torch.nn.utils.clip_grad_norm_(parameters, max_norm, norm_type=norm_type)
|
||||
|
||||
@ -655,6 +766,8 @@ class Accelerator:
|
||||
"""
|
||||
Should be used in place of `torch.nn.utils.clip_grad_value_`.
|
||||
"""
|
||||
if self.distributed_type in [DistributedType.DEEPSPEED, DistributedType.FSDP]:
|
||||
raise Exception("DeepSpeed and FSDP do not support `clip_grad_value_`. Use `clip_grad_norm_` instead.")
|
||||
self.unscale_gradients()
|
||||
torch.nn.utils.clip_grad_value_(parameters, clip_value)
|
||||
|
||||
@ -837,7 +950,7 @@ class Accelerator:
|
||||
self._schedulers = []
|
||||
self._optimizers = []
|
||||
self._models = []
|
||||
self.deepspeed_engine = None
|
||||
self.deepspeed_engine_wrapped = None
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
@ -875,12 +988,19 @@ class Accelerator:
|
||||
|
||||
def get_state_dict(self, model):
|
||||
is_zero_3 = False
|
||||
if is_deepspeed_available():
|
||||
if isinstance(model, DeepSpeedEngineWrapper) and self.distributed_type == DistributedType.DEEPSPEED:
|
||||
is_zero_3 = self.state.deepspeed_plugin.zero_stage == 3
|
||||
if self.distributed_type == DistributedType.DEEPSPEED:
|
||||
is_zero_3 = self.deepspeed_config["zero_optimization"]["stage"] == 3
|
||||
|
||||
if is_zero_3:
|
||||
state_dict = model._zero3_consolidated_16bit_state_dict()
|
||||
if model.zero_gather_16bit_weights_on_model_save():
|
||||
state_dict = model._zero3_consolidated_16bit_state_dict()
|
||||
else:
|
||||
raise ValueError(
|
||||
"Cannot get 16bit model weights because `stage3_gather_16bit_weights_on_model_save` in DeepSpeed config is False. "
|
||||
"To save the model weights in 16bit, set `stage3_gather_16bit_weights_on_model_save` to True in DeepSpeed config file or "
|
||||
"set `zero3_save_16bit_model` to True when using `accelerate config`. "
|
||||
"To save the full checkpoint, run `model.save_checkpoint(save_dir)` and use `zero_to_fp32.py` to recover weights."
|
||||
)
|
||||
else:
|
||||
model = self.unwrap_model(model)
|
||||
state_dict = model.state_dict()
|
||||
|
@ -14,7 +14,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from ...utils import ComputeEnvironment, DistributedType, is_deepspeed_available
|
||||
from ...utils import ComputeEnvironment, DistributedType, is_deepspeed_available, is_transformers_available
|
||||
from .config_args import ClusterConfig
|
||||
from .config_utils import _ask_field, _convert_distributed_mode, _convert_yes_no_to_bool
|
||||
|
||||
@ -77,24 +77,72 @@ def get_cluster_input():
|
||||
), "DeepSpeed is not installed => run `pip3 install deepspeed` or build it from source"
|
||||
|
||||
if distributed_type == DistributedType.DEEPSPEED:
|
||||
deepspeed_config["zero_stage"] = _ask_field(
|
||||
"What should be your DeepSpeed's ZeRO optimization stage (0, 1, 2, 3)? [2]: ",
|
||||
lambda x: int(x),
|
||||
default=2,
|
||||
use_deepspeed_config = _ask_field(
|
||||
"Do you want to specify a json file to a DeepSpeed config? [yes/NO]: ",
|
||||
_convert_yes_no_to_bool,
|
||||
default=False,
|
||||
error_message="Please enter yes or no.",
|
||||
)
|
||||
|
||||
if deepspeed_config["zero_stage"] >= 2:
|
||||
deepspeed_config["offload_optimizer_device"] = _ask_field(
|
||||
"Where to offload optimizer states? [NONE/cpu/nvme]: ",
|
||||
if use_deepspeed_config:
|
||||
deepspeed_config["deepspeed_config_file"] = _ask_field(
|
||||
"Please enter the path to the json DeepSpeed config file: ",
|
||||
lambda x: str(x),
|
||||
default="none",
|
||||
)
|
||||
else:
|
||||
deepspeed_config["zero_stage"] = _ask_field(
|
||||
"What should be your DeepSpeed's ZeRO optimization stage (0, 1, 2, 3)? [2]: ",
|
||||
lambda x: int(x),
|
||||
default=2,
|
||||
)
|
||||
|
||||
deepspeed_config["gradient_accumulation_steps"] = _ask_field(
|
||||
"How many gradient accumulation steps you're passing in your script? [1]: ",
|
||||
lambda x: int(x),
|
||||
default=1,
|
||||
if deepspeed_config["zero_stage"] >= 2:
|
||||
deepspeed_config["offload_optimizer_device"] = _ask_field(
|
||||
"Where to offload optimizer states? [none/cpu/nvme]: ",
|
||||
lambda x: str(x),
|
||||
default="none",
|
||||
)
|
||||
deepspeed_config["offload_param_device"] = _ask_field(
|
||||
"Where to offload parameters? [none/cpu/nvme]: ",
|
||||
lambda x: str(x),
|
||||
default="none",
|
||||
)
|
||||
deepspeed_config["gradient_accumulation_steps"] = _ask_field(
|
||||
"How many gradient accumulation steps you're passing in your script? [1]: ",
|
||||
lambda x: int(x),
|
||||
default=1,
|
||||
)
|
||||
use_gradient_clipping = _ask_field(
|
||||
"Do you want to use gradient clipping? [yes/NO]: ",
|
||||
_convert_yes_no_to_bool,
|
||||
default=False,
|
||||
error_message="Please enter yes or no.",
|
||||
)
|
||||
if use_gradient_clipping:
|
||||
deepspeed_config["gradient_clipping"] = _ask_field(
|
||||
"What is the gradient clipping value? [1.0]: ",
|
||||
lambda x: float(x),
|
||||
default=1.0,
|
||||
)
|
||||
if deepspeed_config["zero_stage"] == 3:
|
||||
deepspeed_config["zero3_save_16bit_model"] = _ask_field(
|
||||
"Do you want to save 16-bit model weights when using ZeRO Stage-3? [yes/NO]: ",
|
||||
_convert_yes_no_to_bool,
|
||||
default=False,
|
||||
error_message="Please enter yes or no.",
|
||||
)
|
||||
deepspeed_config["zero3_init_flag"] = _ask_field(
|
||||
"Do you want to enable `deepspeed.zero.Init` when using ZeRO Stage-3 for constructing massive models? [yes/NO]: ",
|
||||
_convert_yes_no_to_bool,
|
||||
default=False,
|
||||
error_message="Please enter yes or no.",
|
||||
)
|
||||
if deepspeed_config["zero3_init_flag"]:
|
||||
if not is_transformers_available():
|
||||
raise Exception(
|
||||
"When `zero3_init_flag` is set, it requires Transformers to be installed. "
|
||||
"Please run `pip3 install transformers`."
|
||||
)
|
||||
|
||||
fsdp_config = {}
|
||||
if distributed_type in [DistributedType.MULTI_GPU]:
|
||||
@ -155,11 +203,14 @@ def get_cluster_input():
|
||||
num_processes = 1
|
||||
|
||||
if distributed_type != DistributedType.TPU:
|
||||
mixed_precision = _ask_field(
|
||||
"Do you wish to use FP16 or BF16 (mixed precision)? [NO/fp16/bf16]: ",
|
||||
lambda x: str(x).lower(),
|
||||
default="no",
|
||||
)
|
||||
if distributed_type == DistributedType.DEEPSPEED and use_deepspeed_config:
|
||||
mixed_precision = "no"
|
||||
else:
|
||||
mixed_precision = _ask_field(
|
||||
"Do you wish to use FP16 or BF16 (mixed precision)? [NO/fp16/bf16]: ",
|
||||
lambda x: str(x).lower(),
|
||||
default="no",
|
||||
)
|
||||
else:
|
||||
mixed_precision = "no"
|
||||
|
||||
|
@ -31,6 +31,7 @@ from accelerate.utils import (
|
||||
DistributedType,
|
||||
PrecisionType,
|
||||
PrepareForLaunch,
|
||||
is_deepspeed_available,
|
||||
is_sagemaker_available,
|
||||
)
|
||||
from accelerate.utils.versions import is_torch_version
|
||||
@ -57,6 +58,56 @@ def launch_command_parser(subparsers=None):
|
||||
action="store_true",
|
||||
help="Whether to use deepspeed.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--deepspeed_config_file",
|
||||
default=None,
|
||||
type=str,
|
||||
help="DeepSpeed config file.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--zero_stage",
|
||||
default=None,
|
||||
type=int,
|
||||
help="DeepSpeed's ZeRO optimization stage (useful only when `use_deepspeed` flag is passed).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--offload_optimizer_device",
|
||||
default=None,
|
||||
type=str,
|
||||
help="Decides where (none|cpu|nvme) to offload optimizer states (useful only when `use_deepspeed` flag is passed).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--offload_param_device",
|
||||
default=None,
|
||||
type=str,
|
||||
help="Decides where (none|cpu|nvme) to offload parameters (useful only when `use_deepspeed` flag is passed).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--gradient_accumulation_steps",
|
||||
default=None,
|
||||
type=int,
|
||||
help="No of gradient_accumulation_steps used in your training script (useful only when `use_deepspeed` flag is passed).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--gradient_clipping",
|
||||
default=None,
|
||||
type=float,
|
||||
help="gradient clipping value used in your training script (useful only when `use_deepspeed` flag is passed).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--zero3_init_flag",
|
||||
default=None,
|
||||
type=str,
|
||||
help="Decides Whether (true|false) to enable `deepspeed.zero.Init` for constructing massive models. "
|
||||
"Only applicable with DeepSpeed ZeRO Stage-3.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--zero3_save_16bit_model",
|
||||
default=None,
|
||||
type=str,
|
||||
help="Decides Whether (true|false) to save 16-bit model weights when using ZeRO Stage-3. "
|
||||
"Only applicable with DeepSpeed ZeRO Stage-3.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_fsdp",
|
||||
default=False,
|
||||
@ -158,24 +209,6 @@ def launch_command_parser(subparsers=None):
|
||||
"script."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--zero_stage",
|
||||
default=None,
|
||||
type=int,
|
||||
help="DeepSpeed's ZeRO optimization stage (useful only when `use_deepspeed` flag is passed).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--offload_optimizer_device",
|
||||
default=None,
|
||||
type=str,
|
||||
help="Decides where (none|cpu|nvme) to offload optimizer states (useful only when `use_deepspeed` flag is passed).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--gradient_accumulation_steps",
|
||||
default=None,
|
||||
type=int,
|
||||
help="No of gradient_accumulation_steps used in your training script (useful only when `use_deepspeed` flag is passed).",
|
||||
)
|
||||
|
||||
# Other arguments of the training scripts
|
||||
parser.add_argument("training_script_args", nargs=argparse.REMAINDER, help="Arguments of the training script.")
|
||||
@ -279,6 +312,8 @@ def multi_gpu_launcher(args):
|
||||
|
||||
|
||||
def deepspeed_launcher(args):
|
||||
if not is_deepspeed_available():
|
||||
raise ImportError("DeepSpeed is not installed => run `pip3 install deepspeed` or build it from source.")
|
||||
cmd = ["deepspeed", "--no_local_rank"]
|
||||
if args.num_machines > 1:
|
||||
cmd.extend(
|
||||
@ -323,7 +358,12 @@ def deepspeed_launcher(args):
|
||||
current_env["USE_DEEPSPEED"] = "true"
|
||||
current_env["DEEPSPEED_ZERO_STAGE"] = str(args.zero_stage)
|
||||
current_env["GRADIENT_ACCUMULATION_STEPS"] = str(args.gradient_accumulation_steps)
|
||||
current_env["GRADIENT_CLIPPING"] = str(args.gradient_clipping)
|
||||
current_env["DEEPSPEED_OFFLOAD_OPTIMIZER_DEVICE"] = str(args.offload_optimizer_device).lower()
|
||||
current_env["DEEPSPEED_OFFLOAD_PARAM_DEVICE"] = str(args.offload_param_device).lower()
|
||||
current_env["DEEPSPEED_ZERO3_INIT"] = str(args.zero3_init_flag).lower()
|
||||
current_env["DEEPSPEED_ZERO3_SAVE_16BIT_MODEL"] = str(args.zero3_save_16bit_model).lower()
|
||||
current_env["DEEPSPEED_CONFIG_FILE"] = str(args.deepspeed_config_file).lower()
|
||||
|
||||
process = subprocess.Popen(cmd, env=current_env)
|
||||
process.wait()
|
||||
|
@ -108,10 +108,18 @@ class AcceleratorState:
|
||||
mixed_precision = (
|
||||
parse_choice_from_env("MIXED_PRECISION", "no") if mixed_precision is None else mixed_precision
|
||||
)
|
||||
if mixed_precision == "fp16":
|
||||
if (
|
||||
mixed_precision == "fp16"
|
||||
and "fp16" not in deepspeed_plugin.deepspeed_config
|
||||
and "bf16" not in deepspeed_plugin.deepspeed_config
|
||||
):
|
||||
deepspeed_plugin.deepspeed_config.update({"fp16": {"enabled": True}})
|
||||
elif mixed_precision == "bf16":
|
||||
deepspeed_plugin.deepspeed_config.update({"bfloat16": {"enabled": True}})
|
||||
elif (
|
||||
mixed_precision == "bf16"
|
||||
and "fp16" not in deepspeed_plugin.deepspeed_config
|
||||
and "bf16" not in deepspeed_plugin.deepspeed_config
|
||||
):
|
||||
deepspeed_plugin.deepspeed_config.update({"bf16": {"enabled": True}})
|
||||
self.deepspeed_plugin = deepspeed_plugin
|
||||
elif int(os.environ.get("LOCAL_RANK", -1)) != -1 and not cpu:
|
||||
self.distributed_type = DistributedType.MULTI_GPU
|
||||
@ -189,10 +197,11 @@ class AcceleratorState:
|
||||
f"Process index: {self.process_index}\n"
|
||||
f"Local process index: {self.local_process_index}\n"
|
||||
f"Device: {self.device}\n"
|
||||
f"Mixed precision type: {mixed_precision}\n"
|
||||
)
|
||||
if self.distributed_type == DistributedType.DEEPSPEED:
|
||||
repr += f"ds_config: {self.deepspeed_plugin.deepspeed_config}\n"
|
||||
else:
|
||||
f"Mixed precision type: {mixed_precision}\n"
|
||||
return repr
|
||||
|
||||
# For backward compatibility
|
||||
|
@ -26,7 +26,14 @@ from unittest import mock
|
||||
import torch
|
||||
|
||||
from ..state import AcceleratorState
|
||||
from ..utils import gather, is_comet_ml_available, is_tensorboard_available, is_tpu_available, is_wandb_available
|
||||
from ..utils import (
|
||||
gather,
|
||||
is_comet_ml_available,
|
||||
is_deepspeed_available,
|
||||
is_tensorboard_available,
|
||||
is_tpu_available,
|
||||
is_wandb_available,
|
||||
)
|
||||
|
||||
|
||||
def parse_flag_from_env(key, default=False):
|
||||
@ -85,6 +92,13 @@ def require_multi_gpu(test_case):
|
||||
return unittest.skipUnless(torch.cuda.device_count() > 1, "test requires multiple GPUs")(test_case)
|
||||
|
||||
|
||||
def require_deepspeed(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires DeepSpeed installed. These tests are skipped when DeepSpeed isn't installed
|
||||
"""
|
||||
return unittest.skipUnless(is_deepspeed_available(), "test requires DeepSpeed")(test_case)
|
||||
|
||||
|
||||
def require_tensorboard(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires tensorboard installed. These tests are skipped when tensorboard isn't
|
||||
|
@ -27,6 +27,7 @@ from .imports import (
|
||||
is_sagemaker_available,
|
||||
is_tensorboard_available,
|
||||
is_tpu_available,
|
||||
is_transformers_available,
|
||||
is_wandb_available,
|
||||
)
|
||||
from .modeling import (
|
||||
@ -76,7 +77,13 @@ from .versions import compare_versions, is_torch_version
|
||||
|
||||
|
||||
if is_deepspeed_available():
|
||||
from .deepspeed import DeepSpeedEngineWrapper, DeepSpeedOptimizerWrapper
|
||||
from .deepspeed import (
|
||||
DeepSpeedEngineWrapper,
|
||||
DeepSpeedOptimizerWrapper,
|
||||
DeepSpeedSchedulerWrapper,
|
||||
DummyOptim,
|
||||
DummyScheduler,
|
||||
)
|
||||
|
||||
from .launch import PrepareForLaunch
|
||||
from .memory import find_executable_batch_size
|
||||
|
@ -19,8 +19,11 @@ General namespace and dataclass related classes
|
||||
import copy
|
||||
import enum
|
||||
import functools
|
||||
import io
|
||||
import json
|
||||
import os
|
||||
import typing
|
||||
import warnings
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import timedelta
|
||||
from typing import Callable, Iterable, Optional
|
||||
@ -208,10 +211,15 @@ class TensorInformation:
|
||||
|
||||
@dataclass
|
||||
class DeepSpeedPlugin:
|
||||
"""
|
||||
This plugin is used to integrate DeepSpeed.
|
||||
"""
|
||||
|
||||
config_file: str = field(default=None, metadata={"help": "Path to the DeepSpeed config file."})
|
||||
gradient_accumulation_steps: int = field(
|
||||
default=None, metadata={"help": "Number of steps to accumulate gradients before updating optimizer states"}
|
||||
)
|
||||
gradient_clipping: float = field(default=None, metadata={"help": "Enable gradient clipping with value"})
|
||||
zero_stage: int = field(
|
||||
default=None,
|
||||
metadata={"help": "Possible options are 0,1,2,3; Default will be taken from environment variable"},
|
||||
@ -220,37 +228,137 @@ class DeepSpeedPlugin:
|
||||
default=True,
|
||||
metadata={"help": "If both train & eval dataloaders are specified, this will decide the train_batch_size"},
|
||||
)
|
||||
|
||||
auto_opt_mapping: bool = field(
|
||||
default=True,
|
||||
metadata={"help": "whether to map torch.adam to deepspeed optimizer version of adam based on config"},
|
||||
offload_optimizer_device: bool = field(
|
||||
default=None,
|
||||
metadata={"help": "Possible options are none|cpu|nvme. Only applicable with ZeRO Stages 2 and 3."},
|
||||
)
|
||||
offload_param_device: bool = field(
|
||||
default=None,
|
||||
metadata={"help": "Possible options are none|cpu|nvme. Only applicable with ZeRO Stage 3."},
|
||||
)
|
||||
zero3_init_flag: bool = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "Flag to indicate whether to enable `deepspeed.zero.Init` for constructing massive models."
|
||||
"Only applicable with ZeRO Stage-3."
|
||||
},
|
||||
)
|
||||
zero3_save_16bit_model: bool = field(
|
||||
default=None,
|
||||
metadata={"help": "Flag to indicate whether to save 16-bit model. Only applicable with ZeRO Stage-3."},
|
||||
)
|
||||
|
||||
offload_optimizer_device: bool = field(default=None, metadata={"help": "Possible options are none|cpu|nvme"})
|
||||
|
||||
def __post_init__(self):
|
||||
if self.config_file is None:
|
||||
self.config_file = os.environ.get("DEEPSPEED_CONFIG_FILE", "none")
|
||||
if self.config_file != "none":
|
||||
with io.open(self.config_file, "r", encoding="utf-8") as f:
|
||||
self.deepspeed_config = json.load(f)
|
||||
if "gradient_accumulation_steps" not in self.deepspeed_config:
|
||||
self.deepspeed_config["gradient_accumulation_steps"] = 1
|
||||
elif self.deepspeed_config["gradient_accumulation_steps"] == "auto":
|
||||
raise ValueError("gradient_accumulation_steps cannot be set to 'auto' in the DeepSpeed config file.")
|
||||
if "zero_optimization" not in self.deepspeed_config:
|
||||
raise ValueError("Please specify the ZeRO optimization config in the DeepSpeed config file.")
|
||||
else:
|
||||
if self.gradient_accumulation_steps is None:
|
||||
self.gradient_accumulation_steps = int(os.environ.get("GRADIENT_ACCUMULATION_STEPS", 1))
|
||||
|
||||
if self.gradient_accumulation_steps is None:
|
||||
self.gradient_accumulation_steps = int(os.environ.get("GRADIENT_ACCUMULATION_STEPS", 1))
|
||||
if self.gradient_clipping is None:
|
||||
gradient_clipping = os.environ.get("GRADIENT_CLIPPING", "none")
|
||||
if gradient_clipping != "none":
|
||||
self.gradient_clipping = float(gradient_clipping)
|
||||
|
||||
if self.zero_stage is None:
|
||||
self.zero_stage = int(os.environ.get("DEEPSPEED_ZERO_STAGE", 2))
|
||||
if self.zero_stage is None:
|
||||
self.zero_stage = int(os.environ.get("DEEPSPEED_ZERO_STAGE", 2))
|
||||
|
||||
if self.offload_optimizer_device is None:
|
||||
self.offload_optimizer_device = os.environ.get("DEEPSPEED_OFFLOAD_OPTIMIZER_DEVICE", "none")
|
||||
if self.offload_optimizer_device is None:
|
||||
self.offload_optimizer_device = os.environ.get("DEEPSPEED_OFFLOAD_OPTIMIZER_DEVICE", "none")
|
||||
|
||||
self.deepspeed_config = {
|
||||
"train_batch_size": None,
|
||||
"gradient_accumulation_steps": self.gradient_accumulation_steps,
|
||||
"zero_optimization": {
|
||||
"stage": self.zero_stage,
|
||||
"offload_optimizer": {
|
||||
"device": self.offload_optimizer_device,
|
||||
if self.offload_param_device is None:
|
||||
self.offload_param_device = os.environ.get("DEEPSPEED_OFFLOAD_PARAM_DEVICE", "none")
|
||||
|
||||
if self.zero3_save_16bit_model is None:
|
||||
self.zero3_save_16bit_model = os.environ.get("DEEPSPEED_ZERO3_SAVE_16BIT_MODEL", "false") == "true"
|
||||
|
||||
self.deepspeed_config = {
|
||||
"train_batch_size": "auto",
|
||||
"gradient_accumulation_steps": self.gradient_accumulation_steps,
|
||||
"zero_optimization": {
|
||||
"stage": self.zero_stage,
|
||||
"offload_optimizer": {
|
||||
"device": self.offload_optimizer_device,
|
||||
},
|
||||
"offload_param": {
|
||||
"device": self.offload_param_device,
|
||||
},
|
||||
"stage3_gather_16bit_weights_on_model_save": self.zero3_save_16bit_model,
|
||||
},
|
||||
},
|
||||
"steps_per_print": float("inf"), # this will stop deepspeed from logging @ stdout
|
||||
"zero_allow_untested_optimizer": True,
|
||||
}
|
||||
}
|
||||
if self.gradient_clipping:
|
||||
self.deepspeed_config["gradient_clipping"] = self.gradient_clipping
|
||||
self.deepspeed_config["steps_per_print"] = float("inf") # this will stop deepspeed from logging @ stdout
|
||||
if self.zero3_init_flag is None:
|
||||
self.zero3_init_flag = os.environ.get("DEEPSPEED_ZERO3_INIT", "false") == "true"
|
||||
if self.zero3_init_flag and self.deepspeed_config["zero_optimization"]["stage"] != 3:
|
||||
warnings.warn("DeepSpeed Zero3 Init flag is only applicable for ZeRO Stage 3. Setting it to False.")
|
||||
self.zero3_init_flag = False
|
||||
|
||||
def find_config_node(self, ds_key_long):
|
||||
config = self.deepspeed_config
|
||||
|
||||
# find the config node of interest if it exists
|
||||
nodes = ds_key_long.split(".")
|
||||
ds_key = nodes.pop()
|
||||
for node in nodes:
|
||||
config = config.get(node)
|
||||
if config is None:
|
||||
return None, ds_key
|
||||
|
||||
return config, ds_key
|
||||
|
||||
def fill_match(self, ds_key_long, mismatches, must_match=True, **kwargs):
|
||||
config, ds_key = self.find_config_node(ds_key_long)
|
||||
if config is None:
|
||||
return
|
||||
|
||||
if config.get(ds_key) == "auto":
|
||||
if ds_key_long in kwargs:
|
||||
config[ds_key] = kwargs[ds_key_long]
|
||||
return
|
||||
else:
|
||||
raise ValueError(
|
||||
f"`{ds_key_long}` not found in kwargs. "
|
||||
f"Please specify `{ds_key_long}` without `auto`(set to correct value) in the DeepSpeed config file or "
|
||||
"pass it in kwargs."
|
||||
)
|
||||
|
||||
if not must_match:
|
||||
return
|
||||
|
||||
ds_val = config.get(ds_key)
|
||||
if ds_val is not None and ds_key_long in kwargs:
|
||||
if ds_val != kwargs[ds_key_long]:
|
||||
mismatches.append(f"- ds {ds_key_long}={ds_val} vs arg {ds_key_long}={kwargs[ds_key_long]}")
|
||||
|
||||
def deepspeed_config_process(self, prefix="", mismatches=None, config=None, must_match=True, **kwargs):
|
||||
"""Process the DeepSpeed config with the values from the kwargs."""
|
||||
mismatches = [] if mismatches is None else mismatches
|
||||
if config is None:
|
||||
config = self.deepspeed_config
|
||||
for key, value in config.items():
|
||||
if isinstance(value, dict):
|
||||
self.deepspeed_config_process(
|
||||
prefix=prefix + key + ".", mismatches=mismatches, config=value, must_match=must_match, **kwargs
|
||||
)
|
||||
else:
|
||||
self.fill_match(prefix + key, mismatches, must_match=must_match, **kwargs)
|
||||
if len(mismatches) > 0 and prefix == "":
|
||||
mismatches_msg = "\n".join(mismatches)
|
||||
raise ValueError(
|
||||
"Please correct the following DeepSpeed config values that mismatch kwargs "
|
||||
f" values:\n{mismatches_msg}\nThe easiest method is to set these DeepSpeed config values to 'auto'."
|
||||
)
|
||||
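As an aside, a hedged sketch of how `deepspeed_config_process` resolves the `"auto"` placeholders (the config path and values mirror the unit tests added later in this PR; in normal use `Accelerator.prepare` performs this step for you):

```python
from accelerate.utils.dataclasses import DeepSpeedPlugin

# Every "auto" entry in the file must receive a concrete value here, otherwise a
# ValueError is raised; non-"auto" values are instead checked against the kwargs.
plugin = DeepSpeedPlugin(config_file="tests/deepspeed/ds_config_zero2.json")
plugin.deepspeed_config_process(
    **{
        "fp16.enabled": True,
        "bf16.enabled": False,
        "optimizer.params.lr": 5e-5,
        "optimizer.params.weight_decay": 0.0,
        "scheduler.params.warmup_min_lr": 0.0,
        "scheduler.params.warmup_max_lr": 5e-5,
        "scheduler.params.warmup_num_steps": 0,
        "train_micro_batch_size_per_gpu": 16,
        "gradient_clipping": 1.0,
        "train_batch_size": 16,
        "zero_optimization.reduce_bucket_size": 5e5,
    }
)
```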
|
||||
|
||||
@dataclass
|
||||
|
@ -12,58 +12,34 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from accelerate.scheduler import AcceleratedScheduler
|
||||
|
||||
from ..optimizer import AcceleratedOptimizer
|
||||
from .imports import is_apex_available, is_deepspeed_available
|
||||
|
||||
|
||||
if is_deepspeed_available():
|
||||
from deepspeed import DeepSpeedEngine
|
||||
|
||||
if is_apex_available():
|
||||
from apex import amp
|
||||
|
||||
|
||||
class DeepSpeedEngineWrapper(DeepSpeedEngine):
|
||||
class DeepSpeedEngineWrapper:
|
||||
"""
|
||||
Wrapper over deepspeed.DeepSpeedEngine object
|
||||
Internal wrapper for deepspeed.runtime.engine.DeepSpeedEngine. This is used to follow a conventional training loop.
|
||||
|
||||
Args:
|
||||
engine (deepspeed.runtime.engine.DeepSpeedEngine): deepspeed engine to wrap
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
# overwriting micro_steps for user's gradient_accumulation
|
||||
self.micro_steps = -1
|
||||
|
||||
def step(self, lr_kwargs=None):
|
||||
"""DeepSpeedEngine.step() without `micro_steps` update & no profiling"""
|
||||
if self.is_gradient_accumulation_boundary(): # it shouldn't matter whether we keep this line or not
|
||||
if self.progressive_layer_drop:
|
||||
self.progressive_layer_drop.update_state(self.global_steps)
|
||||
|
||||
self._take_model_step(lr_kwargs)
|
||||
def __init__(self, engine):
|
||||
self.engine = engine
|
||||
|
||||
def backward(self, loss):
|
||||
"""DeepSpeedEngine.backward() with with no loss scaling; no profiling but with `micro_steps` update"""
|
||||
# runs backpropagation and handles mixed precision
|
||||
self.engine.backward(loss)
|
||||
|
||||
if self.zero_optimization():
|
||||
self.optimizer.is_gradient_accumulation_boundary = self.is_gradient_accumulation_boundary()
|
||||
self.optimizer.backward(loss)
|
||||
elif self.amp_enabled():
|
||||
# AMP requires delaying unscale when inside gradient accumulation boundaries
|
||||
# https://nvidia.github.io/apex/advanced.html#gradient-accumulation-across-iterations
|
||||
delay_unscale = not self.is_gradient_accumulation_boundary()
|
||||
with amp.scale_loss(loss, self.optimizer, delay_unscale=delay_unscale) as scaled_loss:
|
||||
scaled_loss.backward()
|
||||
elif self.fp16_enabled():
|
||||
self.optimizer.backward(loss)
|
||||
else:
|
||||
loss.backward()
|
||||
|
||||
if self.enable_backward_allreduce:
|
||||
self.allreduce_gradients()
|
||||
|
||||
# this will ensure deepspeed gradient_accumulation matches user's accumulation
|
||||
self.micro_steps += 1
|
||||
# deepspeed `engine.step` performs following operations:
|
||||
# gradient accumulation check
|
||||
# gradient clipping
|
||||
# optimizer step
|
||||
# zero grad
|
||||
# checking overflow
|
||||
# lr_scheduler step
|
||||
self.engine.step()
|
||||
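For orientation, a minimal sketch of the conventional loop this wrapper is designed to preserve (assuming `model`, `optimizer`, `scheduler` and `train_dataloader` came out of `accelerator.prepare(...)` with DeepSpeed enabled; the variable names are illustrative):

```python
for batch in train_dataloader:
    outputs = model(**batch)
    loss = outputs.loss
    # under DeepSpeed this calls engine.backward(loss) followed by engine.step()
    accelerator.backward(loss)
    optimizer.step()       # no-op: the DeepSpeed engine already stepped
    scheduler.step()       # no-op when DeepSpeed manages the scheduler
    optimizer.zero_grad()  # no-op: handled by the engine as well
```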
|
||||
|
||||
class DeepSpeedOptimizerWrapper(AcceleratedOptimizer):
|
||||
@ -75,22 +51,79 @@ class DeepSpeedOptimizerWrapper(AcceleratedOptimizer):
|
||||
The optimizer to wrap.
|
||||
"""
|
||||
|
||||
def __init__(self, optimizer, model: DeepSpeedEngineWrapper):
|
||||
def __init__(self, optimizer):
|
||||
super().__init__(optimizer, device_placement=False, scaler=None)
|
||||
|
||||
self.model = model
|
||||
|
||||
def zero_grad(self, set_to_none=None):
|
||||
pass # `model.step()` is doing that automatically. Therefore, it's implementation is not needed
|
||||
pass # `accelerator.backward(loss)` is doing that automatically. Therefore, its implementation is not needed
|
||||
|
||||
def step(self):
|
||||
"""This will handle optimizer.step() & optimizer.zero_grad() with gradient_accumulation"""
|
||||
self.model.step()
|
||||
pass # `accelerator.backward(loss)` is doing that automatically. Therefore, its implementation is not needed
|
||||
|
||||
@property
|
||||
def is_overflow(self):
|
||||
def step_was_skipped(self):
|
||||
"""Whether or not the optimizer step was done, or skipped because of gradient overflow."""
|
||||
overflow = False
|
||||
if hasattr(self.optimizer, "overflow"):
|
||||
overflow = self.optimizer.overflow
|
||||
return overflow
|
||||
return self.optimizer.overflow
|
||||
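A small usage sketch for the `step_was_skipped` property introduced above (assuming fp16 loss scaling is active and `optimizer` is the wrapped DeepSpeed optimizer; `completed_steps` is just an illustrative counter):

```python
accelerator.backward(loss)
optimizer.step()
optimizer.zero_grad()
# True when the fp16 loss scaler hit an overflow and DeepSpeed skipped the update
if not optimizer.step_was_skipped:
    completed_steps += 1
```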
|
||||
|
||||
class DeepSpeedSchedulerWrapper(AcceleratedScheduler):
|
||||
"""
|
||||
Internal wrapper around a deepspeed scheduler.
|
||||
|
||||
Args:
|
||||
scheduler (`torch.optim.lr_scheduler.LambdaLR`):
|
||||
The scheduler to wrap.
|
||||
optimizers (one or a list of `torch.optim.Optimizer`):
|
||||
"""
|
||||
|
||||
def __init__(self, scheduler, optimizers):
|
||||
super().__init__(scheduler, optimizers)
|
||||
|
||||
def step(self):
|
||||
pass # `accelerator.backward(loss)` is doing that automatically. Therefore, its implementation is not needed
|
||||
|
||||
|
||||
class DummyOptim:
|
||||
"""
|
||||
Dummy optimizer that holds model parameters or param groups; it is primarily used to follow a conventional training
|
||||
loop when the optimizer config is specified in the DeepSpeed config file.
|
||||
|
||||
Args:
|
||||
lr (float):
|
||||
Learning rate.
|
||||
params (iterable): iterable of parameters to optimize or dicts defining
|
||||
parameter groups
|
||||
weight_decay (float):
|
||||
Weight decay.
|
||||
**kwargs:
|
||||
Other arguments.
|
||||
"""
|
||||
|
||||
def __init__(self, params, lr=0.001, weight_decay=0, **kwargs):
|
||||
self.params = params
|
||||
self.lr = lr
|
||||
self.weight_decay = weight_decay
|
||||
self.kwargs = kwargs
|
||||
|
||||
|
||||
class DummyScheduler:
|
||||
"""
|
||||
Dummy scheduler placeholder; it is primarily used to follow a conventional training
|
||||
loop when the scheduler config is specified in the DeepSpeed config file.
|
||||
|
||||
Args:
|
||||
optimizer (`torch.optim.optimizer.Optimizer`):
|
||||
The optimizer to wrap.
|
||||
total_num_steps (int):
|
||||
Total number of steps.
|
||||
warmup_num_steps (int):
|
||||
Number of steps for warmup.
|
||||
**kwargs:
|
||||
Other arguments.
|
||||
"""
|
||||
|
||||
def __init__(self, optimizer, total_num_steps=None, warmup_num_steps=0, **kwargs):
|
||||
self.optimizer = optimizer
|
||||
self.total_num_steps = total_num_steps
|
||||
self.warmup_num_steps = warmup_num_steps
|
||||
self.kwargs = kwargs
|
||||
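A hedged sketch of how these placeholders are meant to be used when the optimizer and scheduler live in the DeepSpeed config file (hyper-parameter names such as `args.learning_rate`, and the surrounding `model`, `train_dataloader` and `accelerator`, are assumed from a training script):

```python
from accelerate.utils import DummyOptim, DummyScheduler

# Stand-ins carrying only the hyper-parameters; the real DeepSpeed optimizer and
# scheduler are created from the config file inside `accelerator.prepare`.
optimizer = DummyOptim(model.parameters(), lr=args.learning_rate, weight_decay=args.weight_decay)
lr_scheduler = DummyScheduler(
    optimizer, total_num_steps=args.max_train_steps, warmup_num_steps=args.num_warmup_steps
)
model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
    model, optimizer, train_dataloader, lr_scheduler
)
```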
|
@ -63,6 +63,10 @@ def is_deepspeed_available():
|
||||
return False
|
||||
|
||||
|
||||
def is_transformers_available():
|
||||
return importlib.util.find_spec("transformers") is not None
|
||||
|
||||
|
||||
def is_tensorboard_available():
|
||||
return importlib.util.find_spec("tensorboard") is not None or importlib.util.find_spec("tensorboardX") is not None
|
||||
|
||||
|
49
tests/deepspeed/ds_config_zero2.json
Normal file
@ -0,0 +1,49 @@
|
||||
{
|
||||
"fp16": {
|
||||
"enabled": "auto",
|
||||
"loss_scale": 0,
|
||||
"loss_scale_window": 1000,
|
||||
"initial_scale_power": 16,
|
||||
"hysteresis": 2,
|
||||
"min_loss_scale": 1
|
||||
},
|
||||
"bf16": {
|
||||
"enabled": "auto"
|
||||
},
|
||||
"optimizer": {
|
||||
"type": "AdamW",
|
||||
"params": {
|
||||
"lr": "auto",
|
||||
"weight_decay": "auto",
|
||||
"torch_adam": true,
|
||||
"adam_w_mode": true
|
||||
}
|
||||
},
|
||||
"scheduler": {
|
||||
"type": "WarmupLR",
|
||||
"params": {
|
||||
"warmup_min_lr": "auto",
|
||||
"warmup_max_lr": "auto",
|
||||
"warmup_num_steps": "auto"
|
||||
}
|
||||
},
|
||||
"zero_optimization": {
|
||||
"stage": 2,
|
||||
"offload_optimizer": {
|
||||
"device": "cpu",
|
||||
"pin_memory": true
|
||||
},
|
||||
"allgather_partitions": true,
|
||||
"allgather_bucket_size": 2e8,
|
||||
"overlap_comm": true,
|
||||
"reduce_scatter": true,
|
||||
"reduce_bucket_size": "auto",
|
||||
"contiguous_gradients": true
|
||||
},
|
||||
"gradient_accumulation_steps": 1,
|
||||
"gradient_clipping": "auto",
|
||||
"steps_per_print": 2000,
|
||||
"train_batch_size": "auto",
|
||||
"train_micro_batch_size_per_gpu": "auto",
|
||||
"wall_clock_breakdown": false
|
||||
}
|
56
tests/deepspeed/ds_config_zero3.json
Normal file
@ -0,0 +1,56 @@
|
||||
{
|
||||
"fp16": {
|
||||
"enabled": "auto",
|
||||
"loss_scale": 0,
|
||||
"loss_scale_window": 1000,
|
||||
"initial_scale_power": 16,
|
||||
"hysteresis": 2,
|
||||
"min_loss_scale": 1
|
||||
},
|
||||
"bf16": {
|
||||
"enabled": "auto"
|
||||
},
|
||||
"optimizer": {
|
||||
"type": "AdamW",
|
||||
"params": {
|
||||
"lr": "auto",
|
||||
"weight_decay": "auto",
|
||||
"torch_adam": true,
|
||||
"adam_w_mode": true
|
||||
}
|
||||
},
|
||||
"scheduler": {
|
||||
"type": "WarmupLR",
|
||||
"params": {
|
||||
"warmup_min_lr": "auto",
|
||||
"warmup_max_lr": "auto",
|
||||
"warmup_num_steps": "auto"
|
||||
}
|
||||
},
|
||||
"zero_optimization": {
|
||||
"stage": 3,
|
||||
"offload_optimizer": {
|
||||
"device": "cpu",
|
||||
"pin_memory": true
|
||||
},
|
||||
"offload_param": {
|
||||
"device": "cpu",
|
||||
"pin_memory": true
|
||||
},
|
||||
"overlap_comm": true,
|
||||
"contiguous_gradients": true,
|
||||
"sub_group_size": 1e9,
|
||||
"reduce_bucket_size": "auto",
|
||||
"stage3_prefetch_bucket_size": "auto",
|
||||
"stage3_param_persistence_threshold": "auto",
|
||||
"stage3_max_live_parameters": 1e9,
|
||||
"stage3_max_reuse_distance": 1e9,
|
||||
"stage3_gather_16bit_weights_on_model_save": "auto"
|
||||
},
|
||||
"gradient_accumulation_steps": 1,
|
||||
"gradient_clipping": "auto",
|
||||
"steps_per_print": 2000,
|
||||
"train_batch_size": "auto",
|
||||
"train_micro_batch_size_per_gpu": "auto",
|
||||
"wall_clock_breakdown": false
|
||||
}
|
581
tests/deepspeed/test_deepspeed.py
Normal file
@ -0,0 +1,581 @@
|
||||
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
import io
|
||||
import itertools
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
import unittest
|
||||
from copy import deepcopy
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from accelerate.accelerator import Accelerator
|
||||
from accelerate.scheduler import AcceleratedScheduler
|
||||
from accelerate.state import AcceleratorState
|
||||
from accelerate.test_utils.testing import require_cuda, require_deepspeed
|
||||
from accelerate.test_utils.training import RegressionDataset
|
||||
from accelerate.utils.dataclasses import DeepSpeedPlugin
|
||||
from accelerate.utils.deepspeed import (
|
||||
DeepSpeedEngineWrapper,
|
||||
DeepSpeedOptimizerWrapper,
|
||||
DeepSpeedSchedulerWrapper,
|
||||
DummyOptim,
|
||||
DummyScheduler,
|
||||
)
|
||||
from parameterized import parameterized
|
||||
from transformers import AutoModel, AutoModelForCausalLM, get_scheduler
|
||||
from transformers.deepspeed import HfDeepSpeedConfig
|
||||
from transformers.testing_utils import mockenv_context
|
||||
from transformers.trainer_utils import set_seed
|
||||
from transformers.utils import is_torch_bf16_available
|
||||
|
||||
|
||||
set_seed(42)
|
||||
|
||||
T5_SMALL = "t5-small"
|
||||
T5_TINY = "patrickvonplaten/t5-tiny-random"
|
||||
GPT2_TINY = "sshleifer/tiny-gpt2"
|
||||
|
||||
ZERO2 = "zero2"
|
||||
ZERO3 = "zero3"
|
||||
|
||||
FP16 = "fp16"
|
||||
BF16 = "bf16"
|
||||
|
||||
CUSTOM_OPTIMIZER = "custom_optimizer"
|
||||
CUSTOM_SCHEDULER = "custom_scheduler"
|
||||
DS_OPTIMIZER = "deepspeed_optimizer"
|
||||
DS_SCHEDULER = "deepspeed_scheduler"
|
||||
|
||||
stages = [ZERO2, ZERO3]
|
||||
optims = [CUSTOM_OPTIMIZER, DS_OPTIMIZER]
|
||||
schedulers = [CUSTOM_SCHEDULER, DS_SCHEDULER]
|
||||
if is_torch_bf16_available():
|
||||
dtypes = [FP16, BF16]
|
||||
else:
|
||||
dtypes = [FP16]
|
||||
|
||||
|
||||
def parameterized_custom_name_func(func, param_num, param):
|
||||
# customize the test name generator function as we want both params to appear in the sub-test
|
||||
# name, as by default it shows only the first param
|
||||
param_based_name = parameterized.to_safe_name("_".join(str(x) for x in param.args))
|
||||
return f"{func.__name__}_{param_based_name}"
|
||||
|
||||
|
||||
# Cartesian product of ZeRO stages and dtypes to test
|
||||
params = list(itertools.product(stages, dtypes))
|
||||
optim_scheduler_params = list(itertools.product(optims, schedulers))
|
||||
|
||||
|
||||
@require_deepspeed
|
||||
@require_cuda
|
||||
class DeepSpeedConfigIntegration(unittest.TestCase):
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
|
||||
self._test_file_path = inspect.getfile(self.__class__)
|
||||
path = Path(self._test_file_path).resolve()
|
||||
self.test_file_dir_str = str(path.parents[0])
|
||||
|
||||
self.ds_config_file = dict(
|
||||
zero2=f"{self.test_file_dir_str}/ds_config_zero2.json",
|
||||
zero3=f"{self.test_file_dir_str}/ds_config_zero3.json",
|
||||
)
|
||||
|
||||
# use self.get_config_dict(stage) to use these to ensure the original is not modified
|
||||
with io.open(self.ds_config_file[ZERO2], "r", encoding="utf-8") as f:
|
||||
config_zero2 = json.load(f)
|
||||
with io.open(self.ds_config_file[ZERO3], "r", encoding="utf-8") as f:
|
||||
config_zero3 = json.load(f)
|
||||
# The following setting slows things down, so don't enable it by default unless needed by a test.
|
||||
# It's in the file as a demo for users since we want everything to work out of the box even if slower.
|
||||
config_zero3["zero_optimization"]["stage3_gather_16bit_weights_on_model_save"] = False
|
||||
|
||||
self.ds_config_dict = dict(zero2=config_zero2, zero3=config_zero3)
|
||||
|
||||
self.dist_env = dict(
|
||||
USE_DEEPSPEED="true",
|
||||
MASTER_ADDR="localhost",
|
||||
MASTER_PORT="10999",
|
||||
RANK="0",
|
||||
LOCAL_RANK="0",
|
||||
WORLD_SIZE="1",
|
||||
)

    def get_config_dict(self, stage):
        # As some tests modify the dict, always make a copy
        return deepcopy(self.ds_config_dict[stage])

    @parameterized.expand(stages, name_func=parameterized_custom_name_func)
    def test_deepspeed_plugin(self, stage):

        # Test zero3_init_flag will be set to False when ZeRO stage != 3
        deepspeed_plugin = DeepSpeedPlugin(
            gradient_accumulation_steps=1,
            gradient_clipping=1.0,
            zero_stage=2,
            offload_optimizer_device="cpu",
            offload_param_device="cpu",
            zero3_save_16bit_model=True,
            zero3_init_flag=True,
        )
        self.assertFalse(deepspeed_plugin.zero3_init_flag)
        deepspeed_plugin.deepspeed_config = None

        # Test zero3_init_flag will be set to True only when ZeRO stage == 3
        deepspeed_plugin = DeepSpeedPlugin(
            gradient_accumulation_steps=1,
            gradient_clipping=1.0,
            zero_stage=3,
            offload_optimizer_device="cpu",
            offload_param_device="cpu",
            zero3_save_16bit_model=True,
            zero3_init_flag=True,
        )
        self.assertTrue(deepspeed_plugin.zero3_init_flag)
        deepspeed_plugin.deepspeed_config = None

        # Test config files are loaded correctly
        deepspeed_plugin = DeepSpeedPlugin(config_file=self.ds_config_file[stage], zero3_init_flag=True)
        if stage == ZERO2:
            self.assertFalse(deepspeed_plugin.zero3_init_flag)
        elif stage == ZERO3:
            self.assertTrue(deepspeed_plugin.zero3_init_flag)
        deepspeed_plugin.deepspeed_config = None

        # Test `gradient_accumulation_steps` is set to 1 if unavailable in config file
        with tempfile.TemporaryDirectory() as dirpath:
            ds_config = self.get_config_dict(stage)
            del ds_config["gradient_accumulation_steps"]
            with open(os.path.join(dirpath, "ds_config.json"), "w") as out_file:
                json.dump(ds_config, out_file)
            deepspeed_plugin = DeepSpeedPlugin(config_file=os.path.join(dirpath, "ds_config.json"))
            self.assertEqual(deepspeed_plugin.deepspeed_config["gradient_accumulation_steps"], 1)
            deepspeed_plugin.deepspeed_config = None

        # Test `ValueError` is raised if `zero_optimization` is unavailable in config file
        with tempfile.TemporaryDirectory() as dirpath:
            ds_config = self.get_config_dict(stage)
            del ds_config["zero_optimization"]
            with open(os.path.join(dirpath, "ds_config.json"), "w") as out_file:
                json.dump(ds_config, out_file)
            with self.assertRaises(ValueError) as cm:
                deepspeed_plugin = DeepSpeedPlugin(config_file=os.path.join(dirpath, "ds_config.json"))
            self.assertTrue(
                "Please specify the ZeRO optimization config in the DeepSpeed config file." in str(cm.exception)
            )
            deepspeed_plugin.deepspeed_config = None

        # Test `deepspeed_config_process`
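        # `deepspeed_config_process` maps each dotted key below onto the nested DeepSpeed config,
        # filling `auto` entries with the supplied values and raising if explicit values mismatch.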
        deepspeed_plugin = DeepSpeedPlugin(config_file=self.ds_config_file[stage])
        kwargs = {
            "fp16.enabled": True,
            "bf16.enabled": False,
            "optimizer.params.lr": 5e-5,
            "optimizer.params.weight_decay": 0.0,
            "scheduler.params.warmup_min_lr": 0.0,
            "scheduler.params.warmup_max_lr": 5e-5,
            "scheduler.params.warmup_num_steps": 0,
            "train_micro_batch_size_per_gpu": 16,
            "gradient_clipping": 1.0,
            "train_batch_size": 16,
            "zero_optimization.reduce_bucket_size": 5e5,
            "zero_optimization.stage3_prefetch_bucket_size": 5e5,
            "zero_optimization.stage3_param_persistence_threshold": 5e5,
            "zero_optimization.stage3_gather_16bit_weights_on_model_save": False,
        }
        deepspeed_plugin.deepspeed_config_process(**kwargs)
        for ds_key_long, value in kwargs.items():
            config, ds_key = deepspeed_plugin.find_config_node(ds_key_long)
            if config.get(ds_key) is not None:
                self.assertEqual(config.get(ds_key), value)

        # Test mismatches
        mismatches = {
            "optimizer.params.lr": 1e-5,
            "optimizer.params.weight_decay": 1e-5,
            "gradient_accumulation_steps": 2,
        }
        with self.assertRaises(ValueError) as cm:
            new_kwargs = deepcopy(kwargs)
            new_kwargs.update(mismatches)
            deepspeed_plugin.deepspeed_config_process(**new_kwargs)
        for key in mismatches.keys():
            self.assertTrue(
                key in str(cm.exception),
                f"{key} is not in the exception message:\n{cm.exception}",
            )

        # Test `ValueError` is raised if some config file fields with `auto` value are missing in `kwargs`
        deepspeed_plugin.deepspeed_config["optimizer"]["params"]["lr"] = "auto"
        with self.assertRaises(ValueError) as cm:
            del kwargs["optimizer.params.lr"]
            deepspeed_plugin.deepspeed_config_process(**kwargs)
        self.assertTrue("`optimizer.params.lr` not found in kwargs." in str(cm.exception))

    @parameterized.expand([FP16, BF16], name_func=parameterized_custom_name_func)
    def test_accelerate_state_deepspeed(self, dtype):
        deepspeed_plugin = DeepSpeedPlugin(
            gradient_accumulation_steps=1,
            gradient_clipping=1.0,
            zero_stage=ZERO2,
            offload_optimizer_device="cpu",
            offload_param_device="cpu",
            zero3_save_16bit_model=True,
            zero3_init_flag=True,
        )
        with mockenv_context(**self.dist_env):
            state = AcceleratorState(mixed_precision=dtype, deepspeed_plugin=deepspeed_plugin, _from_accelerator=True)
            self.assertTrue(state.deepspeed_plugin.deepspeed_config[dtype]["enabled"])
            state.initialized = False

    def test_init_zero3(self):
        deepspeed_plugin = DeepSpeedPlugin(
            gradient_accumulation_steps=1,
            gradient_clipping=1.0,
            zero_stage=3,
            offload_optimizer_device="cpu",
            offload_param_device="cpu",
            zero3_save_16bit_model=True,
            zero3_init_flag=True,
        )

        with mockenv_context(**self.dist_env):
            accelerator = Accelerator(deepspeed_plugin=deepspeed_plugin)
            self.assertTrue("dschf" in accelerator.__dict__)
            self.assertTrue(type(accelerator.dschf) == HfDeepSpeedConfig)
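            # With `zero3_init_flag=True`, the Accelerator keeps an `HfDeepSpeedConfig` around as `dschf`,
            # which `transformers` checks to enable ZeRO-3 `zero.Init` when loading pretrained weights.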

    @parameterized.expand(optim_scheduler_params, name_func=parameterized_custom_name_func)
    def test_prepare_deepspeed(self, optim_type, scheduler_type):
        # 1. Testing with one of the ZeRO Stages is enough to test the `_prepare_deepspeed` function.
        # Here we test using ZeRO Stage 2 with FP16 enabled.
        from deepspeed.runtime.engine import DeepSpeedEngine

        kwargs = {
            "fp16.enabled": True,
            "bf16.enabled": False,
            "optimizer.params.lr": 5e-5,
            "optimizer.params.weight_decay": 0.0,
            "scheduler.params.warmup_min_lr": 0.0,
            "scheduler.params.warmup_max_lr": 5e-5,
            "scheduler.params.warmup_num_steps": 0,
            "train_micro_batch_size_per_gpu": 16,
            "gradient_clipping": 1.0,
            "train_batch_size": 16,
            "zero_optimization.reduce_bucket_size": 5e5,
            "zero_optimization.stage3_prefetch_bucket_size": 5e5,
            "zero_optimization.stage3_param_persistence_threshold": 5e5,
            "zero_optimization.stage3_gather_16bit_weights_on_model_save": False,
        }

        if optim_type == CUSTOM_OPTIMIZER and scheduler_type == CUSTOM_SCHEDULER:
            # Test custom optimizer + custom scheduler
            deepspeed_plugin = DeepSpeedPlugin(
                gradient_accumulation_steps=1,
                gradient_clipping=1.0,
                zero_stage=2,
                offload_optimizer_device="cpu",
                offload_param_device="cpu",
                zero3_save_16bit_model=False,
                zero3_init_flag=False,
            )
            with mockenv_context(**self.dist_env):
                accelerator = Accelerator(mixed_precision="fp16", deepspeed_plugin=deepspeed_plugin)
                self.assertEqual(accelerator.state.deepspeed_plugin.config_file, "none")

                train_set = RegressionDataset(length=80)
                eval_set = RegressionDataset(length=20)
                train_dataloader = DataLoader(train_set, batch_size=16, shuffle=True)
                eval_dataloader = DataLoader(eval_set, batch_size=32, shuffle=False)
                model = AutoModel.from_pretrained(GPT2_TINY)
                optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
                lr_scheduler = get_scheduler(
                    name="linear",
                    optimizer=optimizer,
                    num_warmup_steps=0,
                    num_training_steps=1000,
                )
                dummy_optimizer = DummyOptim(params=model.parameters())
                dummy_lr_scheduler = DummyScheduler(dummy_optimizer)
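                # `DummyOptim`/`DummyScheduler` are placeholders for an optimizer/scheduler defined in a
                # DeepSpeed config file; with no config file here, passing them must raise the errors below.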

                with self.assertRaises(ValueError) as cm:
                    model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare(
                        model, dummy_optimizer, train_dataloader, eval_dataloader, lr_scheduler
                    )
                self.assertTrue(
                    "You cannot create a `DummyOptim` without specifying an optimizer in the config file."
                    in str(cm.exception)
                )
                with self.assertRaises(ValueError) as cm:
                    model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare(
                        model, optimizer, train_dataloader, eval_dataloader, dummy_lr_scheduler
                    )
                self.assertTrue(
                    "You cannot create a `DummyScheduler` without specifying a scheduler in the config file."
                    in str(cm.exception)
                )

                with self.assertRaises(ValueError) as cm:
                    model, optimizer, lr_scheduler = accelerator.prepare(model, optimizer, lr_scheduler)
                self.assertTrue(
                    "You must specify a training or evaluation dataloader in `accelerate.prepare()` when using DeepSpeed."
                    in str(cm.exception)
                )

                model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare(
                    model, optimizer, train_dataloader, eval_dataloader, lr_scheduler
                )
                self.assertTrue(accelerator.deepspeed_config["zero_allow_untested_optimizer"])
                self.assertTrue(accelerator.deepspeed_config["train_batch_size"], 16)
                self.assertEqual(type(model), DeepSpeedEngine)
                self.assertEqual(type(optimizer), DeepSpeedOptimizerWrapper)
                self.assertEqual(type(lr_scheduler), AcceleratedScheduler)
                self.assertEqual(type(accelerator.deepspeed_engine_wrapped), DeepSpeedEngineWrapper)

        elif optim_type == DS_OPTIMIZER and scheduler_type == DS_SCHEDULER:
            # Test DeepSpeed optimizer + DeepSpeed scheduler
            deepspeed_plugin = DeepSpeedPlugin(config_file=self.ds_config_file[ZERO2])
            with mockenv_context(**self.dist_env):
                accelerator = Accelerator(deepspeed_plugin=deepspeed_plugin)
                train_set = RegressionDataset(length=80)
                eval_set = RegressionDataset(length=20)
                train_dataloader = DataLoader(train_set, batch_size=10, shuffle=True)
                eval_dataloader = DataLoader(eval_set, batch_size=5, shuffle=False)
                model = AutoModel.from_pretrained(GPT2_TINY)
                optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
                lr_scheduler = get_scheduler(
                    name="linear",
                    optimizer=optimizer,
                    num_warmup_steps=0,
                    num_training_steps=1000,
                )
                dummy_optimizer = DummyOptim(params=model.parameters())
                dummy_lr_scheduler = DummyScheduler(dummy_optimizer)
                kwargs["train_batch_size"] = (
                    kwargs["train_micro_batch_size_per_gpu"]
                    * deepspeed_plugin.deepspeed_config["gradient_accumulation_steps"]
                    * accelerator.num_processes
                )
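                # DeepSpeed expects train_batch_size == micro_batch_per_gpu * gradient_accumulation_steps * world_size,
                # so recompute it before filling the config's `auto` fields.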
                accelerator.state.deepspeed_plugin.deepspeed_config_process(**kwargs)
                with self.assertRaises(ValueError) as cm:
                    model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare(
                        model, optimizer, train_dataloader, eval_dataloader, dummy_lr_scheduler
                    )
                self.assertTrue(
                    "You cannot specify an optimizer in the config file and in the code at the same time"
                    in str(cm.exception)
                )

                with self.assertRaises(ValueError) as cm:
                    model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare(
                        model, dummy_optimizer, train_dataloader, eval_dataloader, lr_scheduler
                    )
                self.assertTrue(
                    "You cannot specify a scheduler in the config file and in the code at the same time"
                    in str(cm.exception)
                )

                with self.assertRaises(ValueError) as cm:
                    model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare(
                        model, dummy_optimizer, train_dataloader, eval_dataloader, lr_scheduler
                    )
                self.assertTrue(
                    "You cannot specify a scheduler in the config file and in the code at the same time"
                    in str(cm.exception)
                )

                model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare(
                    model, dummy_optimizer, train_dataloader, eval_dataloader, dummy_lr_scheduler
                )
                self.assertTrue(type(model) == DeepSpeedEngine)
                self.assertTrue(type(optimizer) == DeepSpeedOptimizerWrapper)
                self.assertTrue(type(lr_scheduler) == DeepSpeedSchedulerWrapper)
                self.assertTrue(type(accelerator.deepspeed_engine_wrapped) == DeepSpeedEngineWrapper)

        elif optim_type == CUSTOM_OPTIMIZER and scheduler_type == DS_SCHEDULER:
            # Test custom optimizer + DeepSpeed scheduler
            deepspeed_plugin = DeepSpeedPlugin(config_file=self.ds_config_file[ZERO2])
            with mockenv_context(**self.dist_env):
                accelerator = Accelerator(deepspeed_plugin=deepspeed_plugin)
                train_set = RegressionDataset(length=80)
                eval_set = RegressionDataset(length=20)
                train_dataloader = DataLoader(train_set, batch_size=10, shuffle=True)
                eval_dataloader = DataLoader(eval_set, batch_size=5, shuffle=False)
                model = AutoModel.from_pretrained(GPT2_TINY)
                optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
                lr_scheduler = get_scheduler(
                    name="linear",
                    optimizer=optimizer,
                    num_warmup_steps=0,
                    num_training_steps=1000,
                )
                dummy_optimizer = DummyOptim(params=model.parameters())
                dummy_lr_scheduler = DummyScheduler(dummy_optimizer)
                kwargs["train_batch_size"] = (
                    kwargs["train_micro_batch_size_per_gpu"]
                    * deepspeed_plugin.deepspeed_config["gradient_accumulation_steps"]
                    * accelerator.num_processes
                )
                accelerator.state.deepspeed_plugin.deepspeed_config_process(**kwargs)
                del accelerator.state.deepspeed_plugin.deepspeed_config["optimizer"]
                model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare(
                    model, optimizer, train_dataloader, eval_dataloader, dummy_lr_scheduler
                )
                self.assertTrue(type(model) == DeepSpeedEngine)
                self.assertTrue(type(optimizer) == DeepSpeedOptimizerWrapper)
                self.assertTrue(type(lr_scheduler) == DeepSpeedSchedulerWrapper)
                self.assertTrue(type(accelerator.deepspeed_engine_wrapped) == DeepSpeedEngineWrapper)
        elif optim_type == DS_OPTIMIZER and scheduler_type == CUSTOM_SCHEDULER:
            # Test deepspeed optimizer + custom scheduler
            deepspeed_plugin = DeepSpeedPlugin(config_file=self.ds_config_file[ZERO2])
            with mockenv_context(**self.dist_env):
                accelerator = Accelerator(deepspeed_plugin=deepspeed_plugin)
                train_set = RegressionDataset(length=80)
                eval_set = RegressionDataset(length=20)
                train_dataloader = DataLoader(train_set, batch_size=10, shuffle=True)
                eval_dataloader = DataLoader(eval_set, batch_size=5, shuffle=False)
                model = AutoModel.from_pretrained(GPT2_TINY)
                optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
                lr_scheduler = get_scheduler(
                    name="linear",
                    optimizer=optimizer,
                    num_warmup_steps=0,
                    num_training_steps=1000,
                )
                dummy_optimizer = DummyOptim(params=model.parameters())
                dummy_lr_scheduler = DummyScheduler(dummy_optimizer)
                kwargs["train_batch_size"] = (
                    kwargs["train_micro_batch_size_per_gpu"]
                    * deepspeed_plugin.deepspeed_config["gradient_accumulation_steps"]
                    * accelerator.num_processes
                )
                accelerator.state.deepspeed_plugin.deepspeed_config_process(**kwargs)
                del accelerator.state.deepspeed_plugin.deepspeed_config["scheduler"]
                with self.assertRaises(ValueError) as cm:
                    model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare(
                        model, dummy_optimizer, train_dataloader, eval_dataloader, lr_scheduler
                    )
                self.assertTrue(
                    "You can only specify `accelerate.utils.DummyScheduler` in the code when using `accelerate.utils.DummyOptim`."
                    in str(cm.exception)
                )
        accelerator.state.initialized = False

    def test_save_checkpoints(self):
        deepspeed_plugin = DeepSpeedPlugin(
            config_file=self.ds_config_file[ZERO3],
            zero3_init_flag=True,
        )
        del deepspeed_plugin.deepspeed_config["bf16"]
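        # The ZeRO-3 sample config ships a `bf16` block; drop it so only the fp16 settings filled in
        # below are active for this test (presumably to avoid requiring bf16-capable hardware).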
        kwargs = {
            "fp16.enabled": True,
            "bf16.enabled": False,
            "optimizer.params.lr": 5e-5,
            "optimizer.params.weight_decay": 0.0,
            "scheduler.params.warmup_min_lr": 0.0,
            "scheduler.params.warmup_max_lr": 5e-5,
            "scheduler.params.warmup_num_steps": 0,
            "train_micro_batch_size_per_gpu": 16,
            "gradient_clipping": 1.0,
            "train_batch_size": 16,
            "zero_optimization.reduce_bucket_size": 5e5,
            "zero_optimization.stage3_prefetch_bucket_size": 5e5,
            "zero_optimization.stage3_param_persistence_threshold": 5e5,
            "zero_optimization.stage3_gather_16bit_weights_on_model_save": False,
        }

        with mockenv_context(**self.dist_env):
            accelerator = Accelerator(deepspeed_plugin=deepspeed_plugin)
            kwargs["train_batch_size"] = (
                kwargs["train_micro_batch_size_per_gpu"]
                * deepspeed_plugin.deepspeed_config["gradient_accumulation_steps"]
                * accelerator.num_processes
            )
            accelerator.state.deepspeed_plugin.deepspeed_config_process(**kwargs)

            train_set = RegressionDataset(length=80)
            eval_set = RegressionDataset(length=20)
            train_dataloader = DataLoader(train_set, batch_size=16, shuffle=True)
            eval_dataloader = DataLoader(eval_set, batch_size=32, shuffle=False)
            model = AutoModelForCausalLM.from_pretrained("gpt2")
            dummy_optimizer = DummyOptim(params=model.parameters())
            dummy_lr_scheduler = DummyScheduler(dummy_optimizer)

            model, _, train_dataloader, eval_dataloader, _ = accelerator.prepare(
                model, dummy_optimizer, train_dataloader, eval_dataloader, dummy_lr_scheduler
            )
            with self.assertRaises(ValueError) as cm:
                accelerator.get_state_dict(model)
            msg = (
                "Cannot get 16bit model weights because `stage3_gather_16bit_weights_on_model_save` in DeepSpeed config is False. "
                "To save the model weights in 16bit, set `stage3_gather_16bit_weights_on_model_save` to True in DeepSpeed config file or "
                "set `zero3_save_16bit_model` to True when using `accelerate config`. "
                "To save the full checkpoint, run `model.save_checkpoint(save_dir)` and use `zero_to_fp32.py` to recover weights."
            )
            self.assertTrue(msg in str(cm.exception))
            accelerator.state.initialized = False

    def test_autofill_dsconfig(self):
        deepspeed_plugin = DeepSpeedPlugin(
            config_file=self.ds_config_file[ZERO3],
            zero3_init_flag=True,
        )
        del deepspeed_plugin.deepspeed_config["bf16"]
        del deepspeed_plugin.deepspeed_config["fp16"]

        with mockenv_context(**self.dist_env):
            accelerator = Accelerator(deepspeed_plugin=deepspeed_plugin)
            train_set = RegressionDataset(length=80)
            eval_set = RegressionDataset(length=20)
            train_dataloader = DataLoader(train_set, batch_size=16, shuffle=True)
            eval_dataloader = DataLoader(eval_set, batch_size=32, shuffle=False)
            model = AutoModelForCausalLM.from_pretrained("gpt2")
            dummy_optimizer = DummyOptim(params=model.parameters(), lr=5e-5, weight_decay=1e-4)
            dummy_lr_scheduler = DummyScheduler(dummy_optimizer, warmup_num_steps=10, total_num_steps=1000)
            hidden_size = model.config.hidden_size
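            # The `auto` ZeRO sizes are filled from a hidden-size heuristic: reduce_bucket_size = hidden_size**2,
            # stage3_prefetch_bucket_size = 0.9 * hidden_size**2, stage3_param_persistence_threshold = 10 * hidden_size
            # (verified by the assertions below).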
            model, _, train_dataloader, eval_dataloader, _ = accelerator.prepare(
                model, dummy_optimizer, train_dataloader, eval_dataloader, dummy_lr_scheduler
            )
            self.assertEqual(accelerator.deepspeed_config["train_micro_batch_size_per_gpu"], 16)
            self.assertEqual(accelerator.deepspeed_config["train_batch_size"], 16)

            self.assertEqual(accelerator.deepspeed_config["optimizer"]["params"]["lr"], 5e-5)
            self.assertEqual(accelerator.deepspeed_config["optimizer"]["params"]["weight_decay"], 1e-4)

            self.assertEqual(accelerator.deepspeed_config["scheduler"]["params"]["warmup_min_lr"], 0.0)
            self.assertEqual(accelerator.deepspeed_config["scheduler"]["params"]["warmup_max_lr"], 5e-5)
            self.assertEqual(accelerator.deepspeed_config["scheduler"]["params"]["warmup_num_steps"], 10)

            self.assertEqual(accelerator.deepspeed_config["gradient_clipping"], 1.0)
            self.assertEqual(
                accelerator.deepspeed_config["zero_optimization"]["reduce_bucket_size"], hidden_size * hidden_size
            )
            self.assertEqual(
                accelerator.deepspeed_config["zero_optimization"]["stage3_prefetch_bucket_size"],
                0.9 * hidden_size * hidden_size,
            )
            self.assertEqual(
                accelerator.deepspeed_config["zero_optimization"]["stage3_param_persistence_threshold"],
                10 * hidden_size,
            )
            self.assertFalse(
                accelerator.deepspeed_config["zero_optimization"]["stage3_gather_16bit_weights_on_model_save"]
            )
            accelerator.state.initialized = False

@ -32,7 +32,13 @@ from accelerate.utils import write_basic_config
# Should mock `{script_name}.get_dataloaders` via:
# @mock.patch("{script_name}.get_dataloaders", mocked_dataloaders)

EXCLUDE_EXAMPLES = ["cross_validation.py", "multi_process_metrics.py", "memory.py", "fsdp_with_peak_mem_tracking.py"]
EXCLUDE_EXAMPLES = [
    "cross_validation.py",
    "multi_process_metrics.py",
    "memory.py",
    "fsdp_with_peak_mem_tracking.py",
    "deepspeed_with_config_support.py",
]


class ExampleDifferenceTests(unittest.TestCase):