mirror of
https://github.com/huggingface/accelerate.git
synced 2025-10-20 10:03:46 +08:00
Enable FSDP & Deepspeed + FP8 (#2983)
* Working version rebased from main * kwargs * Clean * Fix more nits * Fin * Delay autocast flag * Enable FP8 autocast during eval only if specified * Fin * Rm comment * All done * Zero3 works! * Let the wrapper come off during unwrap_model * Add import check * Migrate all to benchmarks folder and make TE import check work * Add readme * Add README to benchmarks folder * Update CLI to now include fp8 args * Add test config for 0_34 * Finish adding to config yaml * Write docs * Expound docs w/ FP8 * Add to toctree
This commit is contained in:
@ -1,46 +1,5 @@
|
||||
# Big model inference benchmarks
|
||||
# Benchmarks
|
||||
|
||||
Running inference with Accelerate on big models.
|
||||
The folders below contain suites to test various functionalities in Accelerate.
|
||||
|
||||
## Setup
|
||||
|
||||
These benchmarks use the `transformers` library:
|
||||
|
||||
```bash
|
||||
pip install transformers
|
||||
```
|
||||
|
||||
To reproduce or test a new setup, run
|
||||
|
||||
```py
|
||||
python inference_acc.py model_name
|
||||
```
|
||||
|
||||
This script supports `gpt-j-6b`, `gpt-neox`, `opt` (30B version) and `T0pp` out of the box, but you can specify any valid checkpoint for `model_name`.
|
||||
|
||||
To force a different `torch_dtype` than the one in the config: `--torch_dtype xxx`.
|
||||
|
||||
If you get an error linked to disk offload, you need to add the option `--disk-offload`
|
||||
|
||||
## Results
|
||||
|
||||
On a setup with two Titan RTXs (24GB of RAM) and 32GB of RAM, we get the following benchmarks (T0pp does not run in float16, which is why it's not included).
|
||||
|
||||
| Model | Model load time | Generation time | dtype | GPU 0 use | GPU 1 use | CPU use | Disk offload |
|
||||
|:-----:|:---------------:|:---------------:|:-----:|:---------:|:---------:|:-------:|:------------:|
|
||||
| GPT-J-6B | 8.7s | 0.05s per token | float16 | 11.7GB | 0GB | 0GB | no |
|
||||
| GPT-J-6B | 12.4s | 0.06s per token | float32 | 21.9GB | 1.5GB | 0GB | no |
|
||||
| GPT-Neo-X-20B | 30.9s | 0.08s per token | float16 | 21.5GB | 18GB | 0GB | no |
|
||||
| GPT-Neo-X-20B | 78.2s | 10.72s per token | float32 | 20.3GB | 22.7 GB | 24.4GB | yes |
|
||||
| T0pp (11B) | 29.4s | 0.05s per token | float32 | 21.1GB | 21.3GB | 0GB | no |
|
||||
| OPT-30B | 34.5s | 2.37s per token | float16 | 20.7GB | 22.3GB | 14.1GB | no |
|
||||
| OPT-30B | 112.3s | 33.9s per token | float32 | 20.2GB | 21.2GB | 23.5GB | yes |
|
||||
|
||||
Note on the results:
|
||||
- using two GPUs instead of one does not slow down generation
|
||||
- using CPU offload slows down a bit (see OPT-30b)
|
||||
- using disk offload slows down a lot (need to implement prefetching)
|
||||
|
||||
You will also note that Accelerate does not use anymore GPU and CPU RAM than necessary:
|
||||
- peak GPU memory is exactly the size of the model put on a given GPU
|
||||
- peak CPU memory is either the size of the biggest checkpoint shard or the part of the model offloaded on CPU, whichever is bigger.
|
||||
See their relevant README.md's for more information.
|
||||
|
46
benchmarks/big_model_inference/README.md
Normal file
46
benchmarks/big_model_inference/README.md
Normal file
@ -0,0 +1,46 @@
|
||||
# Big model inference benchmarks
|
||||
|
||||
Running inference with Accelerate on big models.
|
||||
|
||||
## Setup
|
||||
|
||||
These benchmarks use the `transformers` library:
|
||||
|
||||
```bash
|
||||
pip install transformers
|
||||
```
|
||||
|
||||
To reproduce or test a new setup, run
|
||||
|
||||
```py
|
||||
python inference_acc.py model_name
|
||||
```
|
||||
|
||||
This script supports `gpt-j-6b`, `gpt-neox`, `opt` (30B version) and `T0pp` out of the box, but you can specify any valid checkpoint for `model_name`.
|
||||
|
||||
To force a different `torch_dtype` than the one in the config: `--torch_dtype xxx`.
|
||||
|
||||
If you get an error linked to disk offload, you need to add the option `--disk-offload`
|
||||
|
||||
## Results
|
||||
|
||||
On a setup with two Titan RTXs (24GB of RAM) and 32GB of RAM, we get the following benchmarks (T0pp does not run in float16, which is why it's not included).
|
||||
|
||||
| Model | Model load time | Generation time | dtype | GPU 0 use | GPU 1 use | CPU use | Disk offload |
|
||||
|:-----:|:---------------:|:---------------:|:-----:|:---------:|:---------:|:-------:|:------------:|
|
||||
| GPT-J-6B | 8.7s | 0.05s per token | float16 | 11.7GB | 0GB | 0GB | no |
|
||||
| GPT-J-6B | 12.4s | 0.06s per token | float32 | 21.9GB | 1.5GB | 0GB | no |
|
||||
| GPT-Neo-X-20B | 30.9s | 0.08s per token | float16 | 21.5GB | 18GB | 0GB | no |
|
||||
| GPT-Neo-X-20B | 78.2s | 10.72s per token | float32 | 20.3GB | 22.7 GB | 24.4GB | yes |
|
||||
| T0pp (11B) | 29.4s | 0.05s per token | float32 | 21.1GB | 21.3GB | 0GB | no |
|
||||
| OPT-30B | 34.5s | 2.37s per token | float16 | 20.7GB | 22.3GB | 14.1GB | no |
|
||||
| OPT-30B | 112.3s | 33.9s per token | float32 | 20.2GB | 21.2GB | 23.5GB | yes |
|
||||
|
||||
Note on the results:
|
||||
- using two GPUs instead of one does not slow down generation
|
||||
- using CPU offload slows down a bit (see OPT-30b)
|
||||
- using disk offload slows down a lot (need to implement prefetching)
|
||||
|
||||
You will also note that Accelerate does not use anymore GPU and CPU RAM than necessary:
|
||||
- peak GPU memory is exactly the size of the model put on a given GPU
|
||||
- peak CPU memory is either the size of the biggest checkpoint shard or the part of the model offloaded on CPU, whichever is bigger.
|
12
benchmarks/fp8/Dockerfile
Normal file
12
benchmarks/fp8/Dockerfile
Normal file
@ -0,0 +1,12 @@
|
||||
FROM nvcr.io/nvidia/pytorch:24.07-py3
|
||||
|
||||
RUN pip install transformers evaluate datasets
|
||||
RUN git clone https://github.com/huggingface/accelerate.git
|
||||
|
||||
RUN cd accelerate && \
|
||||
pip install -e . && \
|
||||
cd benchmarks/fp8
|
||||
|
||||
RUN /bin/bash
|
||||
|
||||
|
30
benchmarks/fp8/README.md
Normal file
30
benchmarks/fp8/README.md
Normal file
@ -0,0 +1,30 @@
|
||||
# FP8 Benchmarks
|
||||
|
||||
Comparing and running [TransformerEngine](https://github.com/NVIDIA/TransformerEngine) FP8 with accelerate
|
||||
|
||||
## Overview
|
||||
|
||||
This repo provides scripts which compare native TransformerEngine model training against `accelerate`'s own integration. Each modeling type is segmented out via a script, supporting the following:
|
||||
|
||||
* Single GPU training (`non_distributed.py`)
|
||||
* Multi-GPU training via DistributedDataParallelism (`ddp.py`)
|
||||
* Fully Sharded Data Parallelism (`fsdp.py`)
|
||||
* DeepSpeed ZeRO 1-3 (`deepspeed.py`)
|
||||
|
||||
To run them, it's recommended to use a docker image (see the attached `Dockerfile`) and not install `TransformerEngine` manually.
|
||||
|
||||
## Running:
|
||||
|
||||
You can run all scripts using the core `accelerate launch` command without any `accelerate config` being needed.
|
||||
|
||||
For single GPU, run it via `python`:
|
||||
|
||||
```bash
|
||||
python non_distributed.py
|
||||
```
|
||||
|
||||
For the rest, run it via `accelerate launch`:
|
||||
|
||||
```bash
|
||||
accelerate launch ddp.py # or distrib_deepspeed.py, ddp.py
|
||||
```
|
143
benchmarks/fp8/ddp.py
Normal file
143
benchmarks/fp8/ddp.py
Normal file
@ -0,0 +1,143 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
This script tests to ensure that `accelerate` performs at the same level as raw `TransformersEngine`.
|
||||
|
||||
This particular script verifies this for DDP training.
|
||||
"""
|
||||
import evaluate
|
||||
import torch
|
||||
import transformer_engine.common.recipe as te_recipe
|
||||
import transformer_engine.pytorch as te
|
||||
from fp8_utils import evaluate_model, get_named_parameters, get_training_utilities
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from transformer_engine.common.recipe import DelayedScaling
|
||||
|
||||
from accelerate import Accelerator
|
||||
from accelerate.state import AcceleratorState
|
||||
from accelerate.utils import FP8RecipeKwargs, set_seed
|
||||
from accelerate.utils.transformer_engine import convert_model
|
||||
|
||||
|
||||
MODEL_NAME = "bert-base-cased"
|
||||
METRIC = evaluate.load("glue", "mrpc")
|
||||
|
||||
|
||||
def train_baseline():
|
||||
set_seed(42)
|
||||
model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = get_training_utilities(MODEL_NAME)
|
||||
accelerator = Accelerator()
|
||||
device = accelerator.device
|
||||
model.to(device)
|
||||
|
||||
# Convert the model to TE
|
||||
old_named_params = get_named_parameters(model)
|
||||
|
||||
with torch.no_grad():
|
||||
convert_model(model)
|
||||
|
||||
FP8_RECIPE_KWARGS = {"fp8_format": te_recipe.Format.HYBRID, "amax_history_len": 32, "amax_compute_algo": "max"}
|
||||
fp8_recipe = DelayedScaling(**FP8_RECIPE_KWARGS)
|
||||
|
||||
new_named_params = get_named_parameters(model)
|
||||
|
||||
# Convert the model to DDP
|
||||
device_ids, output_device = [accelerator.local_process_index], accelerator.local_process_index
|
||||
model = DDP(model, device_ids=device_ids, output_device=output_device)
|
||||
|
||||
mapping = {p: new_named_params[n] for n, p in old_named_params.items()}
|
||||
for param_group in optimizer.param_groups:
|
||||
param_group["params"] = [mapping[p] for p in param_group["params"]]
|
||||
|
||||
base_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator)
|
||||
model.train()
|
||||
|
||||
for _ in range(2):
|
||||
for batch in train_dataloader:
|
||||
with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
|
||||
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
|
||||
batch = batch.to(device)
|
||||
outputs = model(**batch)
|
||||
loss = outputs.loss
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
lr_scheduler.step()
|
||||
|
||||
trained_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator)
|
||||
|
||||
assert (
|
||||
trained_model_results["accuracy"] > base_model_results["accuracy"]
|
||||
), f'Accuracy should be higher for the trained model: {trained_model_results["accuracy"]} > {base_model_results["accuracy"]}'
|
||||
assert (
|
||||
trained_model_results["f1"] > base_model_results["f1"]
|
||||
), f'F1 score should be higher for the trained model: {trained_model_results["f1"]} > {base_model_results["f1"]}'
|
||||
|
||||
return base_model_results, trained_model_results
|
||||
|
||||
|
||||
def train_integration():
|
||||
FP8_RECIPE_KWARGS = {"fp8_format": "HYBRID", "amax_history_len": 32, "amax_compute_algo": "max"}
|
||||
kwargs_handlers = [FP8RecipeKwargs(backend="TE", **FP8_RECIPE_KWARGS)]
|
||||
AcceleratorState()._reset_state(True)
|
||||
accelerator = Accelerator(mixed_precision="fp8", kwargs_handlers=kwargs_handlers)
|
||||
set_seed(42)
|
||||
model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = get_training_utilities(
|
||||
MODEL_NAME, accelerator=accelerator
|
||||
)
|
||||
|
||||
model, optimizer = accelerator.prepare(model, optimizer)
|
||||
base_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator)
|
||||
model.train()
|
||||
|
||||
for _ in range(2):
|
||||
for batch in train_dataloader:
|
||||
outputs = model(**batch)
|
||||
loss = outputs.loss
|
||||
accelerator.backward(loss)
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
lr_scheduler.step()
|
||||
|
||||
trained_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator)
|
||||
|
||||
assert (
|
||||
trained_model_results["accuracy"] > base_model_results["accuracy"]
|
||||
), f'Accuracy should be higher for the trained model: {trained_model_results["accuracy"]} > {base_model_results["accuracy"]}'
|
||||
assert (
|
||||
trained_model_results["f1"] > base_model_results["f1"]
|
||||
), f'F1 score should be higher for the trained model: {trained_model_results["f1"]} > {base_model_results["f1"]}'
|
||||
|
||||
return base_model_results, trained_model_results
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
baseline_not_trained, baseline_trained = train_baseline()
|
||||
accelerator_not_trained, accelerator_trained = train_integration()
|
||||
|
||||
assert (
|
||||
baseline_not_trained["accuracy"] == accelerator_not_trained["accuracy"]
|
||||
), f'Accuracy should be the same for the baseline and accelerator: {baseline_not_trained["accuracy"]} == {accelerator_not_trained["accuracy"]}'
|
||||
assert (
|
||||
baseline_not_trained["f1"] == accelerator_not_trained["f1"]
|
||||
), f'F1 score should be the same for the baseline and accelerator: {baseline_not_trained["f1"]} == {accelerator_not_trained["f1"]}'
|
||||
assert (
|
||||
baseline_trained["accuracy"] == accelerator_trained["accuracy"]
|
||||
), f'Accuracy should be the same for the baseline and accelerator: {baseline_trained["accuracy"]} == {accelerator_trained["accuracy"]}'
|
||||
assert (
|
||||
baseline_trained["f1"] == accelerator_trained["f1"]
|
||||
), f'F1 score should be the same for the baseline and accelerator: {baseline_trained["f1"]} == {accelerator_trained["f1"]}'
|
||||
|
||||
torch.distributed.destroy_process_group()
|
189
benchmarks/fp8/distrib_deepspeed.py
Normal file
189
benchmarks/fp8/distrib_deepspeed.py
Normal file
@ -0,0 +1,189 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
This script tests to ensure that `accelerate` performs at the same level as raw `TransformersEngine`.
|
||||
|
||||
This particular script verifies this for DDP training.
|
||||
"""
|
||||
from unittest.mock import patch
|
||||
|
||||
import deepspeed
|
||||
import evaluate
|
||||
import torch
|
||||
import transformer_engine.common.recipe as te_recipe
|
||||
import transformer_engine.pytorch as te
|
||||
from fp8_utils import evaluate_model, get_named_parameters, get_training_utilities
|
||||
from transformer_engine.common.recipe import DelayedScaling
|
||||
|
||||
from accelerate import Accelerator, DeepSpeedPlugin
|
||||
from accelerate.state import AcceleratorState
|
||||
from accelerate.utils import FP8RecipeKwargs, set_seed
|
||||
from accelerate.utils.transformer_engine import convert_model
|
||||
|
||||
|
||||
MODEL_NAME = "bert-base-cased"
|
||||
METRIC = evaluate.load("glue", "mrpc")
|
||||
|
||||
|
||||
def train_baseline(zero_stage: int = 1):
|
||||
# This forces transformers to think Zero-3 Init should be used
|
||||
with patch("transformers.integrations.deepspeed.is_deepspeed_zero3_enabled") as mock:
|
||||
mock.return_value = zero_stage == 3
|
||||
set_seed(42)
|
||||
|
||||
accelerator = Accelerator()
|
||||
model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = get_training_utilities(
|
||||
MODEL_NAME, accelerator=accelerator
|
||||
)
|
||||
|
||||
# Convert the model to TE
|
||||
old_named_params = get_named_parameters(model)
|
||||
|
||||
with torch.no_grad():
|
||||
convert_model(model)
|
||||
new_named_params = get_named_parameters(model)
|
||||
|
||||
mapping = {p: new_named_params[n] for n, p in old_named_params.items()}
|
||||
for param_group in optimizer.param_groups:
|
||||
param_group["params"] = [mapping[p] for p in param_group["params"]]
|
||||
|
||||
FP8_RECIPE_KWARGS = {"fp8_format": te_recipe.Format.HYBRID, "amax_history_len": 32, "amax_compute_algo": "max"}
|
||||
fp8_recipe = DelayedScaling(**FP8_RECIPE_KWARGS)
|
||||
|
||||
import numpy as np
|
||||
|
||||
config = {
|
||||
"train_batch_size": 32,
|
||||
"train_micro_batch_size_per_gpu": 16,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"zero_optimization": {
|
||||
"stage": zero_stage,
|
||||
"offload_optimizer": {"device": "none", "nvme_path": None},
|
||||
"offload_param": {"device": "none", "nvme_path": None},
|
||||
"stage3_gather_16bit_weights_on_model_save": False,
|
||||
},
|
||||
"gradient_clipping": 1.0,
|
||||
"steps_per_print": np.inf,
|
||||
"bf16": {"enabled": True},
|
||||
"fp16": {"enabled": False},
|
||||
"zero_allow_untested_optimizer": True,
|
||||
}
|
||||
|
||||
(
|
||||
model,
|
||||
optimizer,
|
||||
_,
|
||||
_,
|
||||
) = deepspeed.initialize(
|
||||
model=model,
|
||||
optimizer=optimizer,
|
||||
config_params=config,
|
||||
)
|
||||
|
||||
base_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator)
|
||||
model.train()
|
||||
|
||||
model_outputs = []
|
||||
data = []
|
||||
|
||||
for _ in range(2):
|
||||
for batch in train_dataloader:
|
||||
with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
|
||||
outputs = model(**batch)
|
||||
data.append(batch.to("cpu"))
|
||||
model_outputs.append(outputs.logits.to("cpu"))
|
||||
loss = outputs.loss
|
||||
model.backward(loss)
|
||||
model.step()
|
||||
for _ in range(accelerator.num_processes):
|
||||
lr_scheduler.step()
|
||||
|
||||
trained_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator)
|
||||
model.destroy()
|
||||
assert (
|
||||
trained_model_results["accuracy"] > base_model_results["accuracy"]
|
||||
), f'Accuracy should be higher for the trained model: {trained_model_results["accuracy"]} > {base_model_results["accuracy"]}'
|
||||
assert (
|
||||
trained_model_results["f1"] > base_model_results["f1"]
|
||||
), f'F1 score should be higher for the trained model: {trained_model_results["f1"]} > {base_model_results["f1"]}'
|
||||
|
||||
return base_model_results, trained_model_results, model_outputs, data
|
||||
|
||||
|
||||
def train_integration(zero_stage: int = 1):
|
||||
set_seed(42)
|
||||
FP8_RECIPE_KWARGS = {"fp8_format": "HYBRID", "amax_history_len": 32, "amax_compute_algo": "max"}
|
||||
kwargs_handlers = [FP8RecipeKwargs(backend="TE", **FP8_RECIPE_KWARGS)]
|
||||
AcceleratorState()._reset_state(True)
|
||||
deepspeed_plugin = DeepSpeedPlugin(
|
||||
zero_stage=zero_stage,
|
||||
zero3_init_flag=zero_stage == 3,
|
||||
)
|
||||
accelerator = Accelerator(
|
||||
mixed_precision="fp8", kwargs_handlers=kwargs_handlers, deepspeed_plugin=deepspeed_plugin
|
||||
)
|
||||
accelerator.state.deepspeed_plugin.deepspeed_config["train_micro_batch_size_per_gpu"] = 16
|
||||
|
||||
model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = get_training_utilities(
|
||||
MODEL_NAME, accelerator=accelerator
|
||||
)
|
||||
|
||||
model, optimizer, lr_scheduler = accelerator.prepare(model, optimizer, lr_scheduler)
|
||||
base_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator)
|
||||
model.train()
|
||||
model_outputs = []
|
||||
data = []
|
||||
for _ in range(2):
|
||||
for batch in train_dataloader:
|
||||
outputs = model(**batch)
|
||||
data.append(batch.to("cpu"))
|
||||
model_outputs.append(outputs.logits.to("cpu"))
|
||||
loss = outputs.loss
|
||||
accelerator.backward(loss)
|
||||
optimizer.step()
|
||||
lr_scheduler.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
trained_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator)
|
||||
model.destroy()
|
||||
assert (
|
||||
trained_model_results["accuracy"] > base_model_results["accuracy"]
|
||||
), f'Accuracy should be higher for the trained model: {trained_model_results["accuracy"]} > {base_model_results["accuracy"]}'
|
||||
assert (
|
||||
trained_model_results["f1"] > base_model_results["f1"]
|
||||
), f'F1 score should be higher for the trained model: {trained_model_results["f1"]} > {base_model_results["f1"]}'
|
||||
|
||||
return base_model_results, trained_model_results, model_outputs, data
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# for zero_stage in [1, 2, 3]:
|
||||
zero_stage = 1
|
||||
baseline_not_trained, baseline_trained, baseline_outputs, baseline_data = train_baseline(zero_stage)
|
||||
accelerator_not_trained, accelerator_trained, accelerator_outputs, accelerator_data = train_integration(zero_stage)
|
||||
assert (
|
||||
baseline_not_trained["accuracy"] == accelerator_not_trained["accuracy"]
|
||||
), f'ZERO stage {zero_stage}: Accuracy should be the same for the baseline and accelerator: {baseline_not_trained["accuracy"]} == {accelerator_not_trained["accuracy"]}'
|
||||
assert (
|
||||
baseline_not_trained["f1"] == accelerator_not_trained["f1"]
|
||||
), f'ZERO stage {zero_stage}: F1 score should be the same for the baseline and accelerator: {baseline_not_trained["f1"]} == {accelerator_not_trained["f1"]}'
|
||||
assert (
|
||||
baseline_trained["accuracy"] == accelerator_trained["accuracy"]
|
||||
), f'ZERO stage {zero_stage}: Accuracy should be the same for the baseline and accelerator: {baseline_trained["accuracy"]} == {accelerator_trained["accuracy"]}'
|
||||
assert (
|
||||
baseline_trained["f1"] == accelerator_trained["f1"]
|
||||
), f'ZERO stage {zero_stage}: F1 score should be the same for the baseline and accelerator: {baseline_trained["f1"]} == {accelerator_trained["f1"]}'
|
||||
|
||||
torch.distributed.destroy_process_group()
|
115
benchmarks/fp8/fp8_utils.py
Normal file
115
benchmarks/fp8/fp8_utils.py
Normal file
@ -0,0 +1,115 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import torch
|
||||
|
||||
|
||||
def get_dataloaders(model_name: str, batch_size: int = 16):
|
||||
from datasets import load_dataset
|
||||
from torch.utils.data import DataLoader
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||
datasets = load_dataset("glue", "mrpc")
|
||||
|
||||
def tokenize_function(examples):
|
||||
# max_length=None => use the model max length (it's actually the default)
|
||||
outputs = tokenizer(examples["sentence1"], examples["sentence2"], truncation=True, max_length=None)
|
||||
return outputs
|
||||
|
||||
# Apply the method we just defined to all the examples in all the splits of the dataset
|
||||
# starting with the main process first:
|
||||
tokenized_datasets = datasets.map(
|
||||
tokenize_function,
|
||||
batched=True,
|
||||
remove_columns=["idx", "sentence1", "sentence2"],
|
||||
)
|
||||
|
||||
# We also rename the 'label' column to 'labels' which is the expected name for labels by the models of the
|
||||
# transformers library
|
||||
tokenized_datasets = tokenized_datasets.rename_column("label", "labels")
|
||||
|
||||
def collate_fn(examples):
|
||||
return tokenizer.pad(
|
||||
examples,
|
||||
padding="longest",
|
||||
pad_to_multiple_of=16, # Specific for FP8
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
# Instantiate dataloaders.
|
||||
train_dataloader = DataLoader(
|
||||
tokenized_datasets["train"], shuffle=True, collate_fn=collate_fn, batch_size=batch_size, drop_last=True
|
||||
)
|
||||
eval_dataloader = DataLoader(
|
||||
tokenized_datasets["validation"],
|
||||
shuffle=False,
|
||||
collate_fn=collate_fn,
|
||||
batch_size=16,
|
||||
drop_last=True,
|
||||
)
|
||||
|
||||
return train_dataloader, eval_dataloader
|
||||
|
||||
|
||||
def get_training_utilities(model_name: str, batch_size: int = 16, accelerator=None):
|
||||
"""
|
||||
Returns a tuple of:
|
||||
- Model
|
||||
- Optimizer
|
||||
- Train dataloader (prepared)
|
||||
- Eval dataloader (prepared)
|
||||
- LR Scheduler
|
||||
Suitable for training on the MRPC dataset
|
||||
"""
|
||||
from torch.optim import AdamW
|
||||
from transformers import AutoModelForSequenceClassification, get_linear_schedule_with_warmup
|
||||
|
||||
from accelerate import Accelerator
|
||||
|
||||
if accelerator is None:
|
||||
accelerator = Accelerator()
|
||||
model = AutoModelForSequenceClassification.from_pretrained(model_name)
|
||||
train_dataloader, eval_dataloader = get_dataloaders(model_name, batch_size)
|
||||
optimizer = AdamW(model.parameters(), lr=0.0001)
|
||||
lr_scheduler = get_linear_schedule_with_warmup(
|
||||
optimizer=optimizer,
|
||||
num_warmup_steps=100,
|
||||
num_training_steps=len(train_dataloader) * 2,
|
||||
)
|
||||
train_dataloader, eval_dataloader = accelerator.prepare(train_dataloader, eval_dataloader)
|
||||
return model, optimizer, train_dataloader, eval_dataloader, lr_scheduler
|
||||
|
||||
|
||||
def get_named_parameters(model):
|
||||
"""
|
||||
Same thing as `Accelerator.get_named_parameters` Returns a list of the named parameters of the model (extracted
|
||||
from parallel)
|
||||
"""
|
||||
from accelerate.utils import extract_model_from_parallel
|
||||
|
||||
model = extract_model_from_parallel(model)
|
||||
return {n: p for n, p in model.named_parameters()}
|
||||
|
||||
|
||||
def evaluate_model(model, dataloader, metric, accelerator=None):
|
||||
"Turns model to .eval(), runs dataloader, calculates metric, then turns eval back on"
|
||||
model.eval()
|
||||
for step, batch in enumerate(dataloader):
|
||||
with torch.no_grad():
|
||||
outputs = model(**batch)
|
||||
predictions = outputs.logits.argmax(dim=-1)
|
||||
if accelerator is not None and accelerator.num_processes > 1:
|
||||
predictions, references = accelerator.gather_for_metrics((predictions, batch["labels"]))
|
||||
metric.add_batch(predictions=predictions, references=references)
|
||||
return metric.compute()
|
160
benchmarks/fp8/fsdp.py
Normal file
160
benchmarks/fp8/fsdp.py
Normal file
@ -0,0 +1,160 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
This script tests to ensure that `accelerate` performs at the same level as raw `TransformersEngine`.
|
||||
|
||||
This particular script verifies this for FSDP training.
|
||||
"""
|
||||
from functools import partial
|
||||
|
||||
import evaluate
|
||||
import torch
|
||||
import transformer_engine.common.recipe as te_recipe
|
||||
import transformer_engine.pytorch as te
|
||||
from fp8_utils import evaluate_model, get_named_parameters, get_training_utilities
|
||||
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
|
||||
from torch.distributed.fsdp import MixedPrecision
|
||||
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
|
||||
from transformer_engine.common.recipe import DelayedScaling
|
||||
from transformers.models.bert import BertLayer
|
||||
|
||||
from accelerate import Accelerator
|
||||
from accelerate import FullyShardedDataParallelPlugin as FSDPPlugin
|
||||
from accelerate.state import AcceleratorState
|
||||
from accelerate.utils import FP8RecipeKwargs, set_seed
|
||||
from accelerate.utils.transformer_engine import convert_model
|
||||
|
||||
|
||||
MODEL_NAME = "bert-base-cased"
|
||||
METRIC = evaluate.load("glue", "mrpc")
|
||||
|
||||
FSDP_WRAP_POLICY = partial(transformer_auto_wrap_policy, transformer_layer_cls={BertLayer})
|
||||
|
||||
|
||||
def train_baseline():
|
||||
set_seed(42)
|
||||
model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = get_training_utilities(MODEL_NAME)
|
||||
accelerator = Accelerator()
|
||||
device = accelerator.device
|
||||
model.to(device)
|
||||
|
||||
# Convert the model to TE
|
||||
old_named_params = get_named_parameters(model)
|
||||
|
||||
with torch.no_grad():
|
||||
convert_model(model)
|
||||
|
||||
FP8_RECIPE_KWARGS = {"fp8_format": te_recipe.Format.HYBRID, "amax_history_len": 32, "amax_compute_algo": "max"}
|
||||
fp8_recipe = DelayedScaling(**FP8_RECIPE_KWARGS)
|
||||
|
||||
new_named_params = get_named_parameters(model)
|
||||
|
||||
# Convert the model to FSDP
|
||||
model = FSDP(
|
||||
model,
|
||||
use_orig_params=True,
|
||||
mixed_precision=MixedPrecision(param_dtype=torch.bfloat16, reduce_dtype=torch.float32),
|
||||
auto_wrap_policy=FSDP_WRAP_POLICY,
|
||||
)
|
||||
|
||||
mapping = {p: new_named_params[n] for n, p in old_named_params.items()}
|
||||
for param_group in optimizer.param_groups:
|
||||
param_group["params"] = [mapping[p] for p in param_group["params"]]
|
||||
|
||||
base_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator)
|
||||
model.train()
|
||||
|
||||
for _ in range(2):
|
||||
for batch in train_dataloader:
|
||||
with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
|
||||
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
|
||||
batch = batch.to(device)
|
||||
outputs = model(**batch)
|
||||
loss = outputs.loss
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
lr_scheduler.step()
|
||||
|
||||
trained_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator)
|
||||
|
||||
assert (
|
||||
trained_model_results["accuracy"] > base_model_results["accuracy"]
|
||||
), f'Accuracy should be higher for the trained model: {trained_model_results["accuracy"]} > {base_model_results["accuracy"]}'
|
||||
assert (
|
||||
trained_model_results["f1"] > base_model_results["f1"]
|
||||
), f'F1 score should be higher for the trained model: {trained_model_results["f1"]} > {base_model_results["f1"]}'
|
||||
|
||||
return base_model_results, trained_model_results
|
||||
|
||||
|
||||
def train_integration():
|
||||
FP8_RECIPE_KWARGS = {"fp8_format": "HYBRID", "amax_history_len": 32, "amax_compute_algo": "max"}
|
||||
kwargs_handlers = [FP8RecipeKwargs(backend="TE", **FP8_RECIPE_KWARGS)]
|
||||
AcceleratorState()._reset_state(True)
|
||||
fsdp_plugin = FSDPPlugin(
|
||||
auto_wrap_policy=FSDP_WRAP_POLICY,
|
||||
use_orig_params=True,
|
||||
mixed_precision_policy=MixedPrecision(param_dtype=torch.bfloat16, reduce_dtype=torch.float32),
|
||||
)
|
||||
accelerator = Accelerator(mixed_precision="fp8", fsdp_plugin=fsdp_plugin, kwargs_handlers=kwargs_handlers)
|
||||
set_seed(42)
|
||||
model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = get_training_utilities(
|
||||
MODEL_NAME, accelerator=accelerator
|
||||
)
|
||||
|
||||
model, optimizer = accelerator.prepare(model, optimizer)
|
||||
base_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator)
|
||||
model.train()
|
||||
|
||||
for _ in range(2):
|
||||
for batch in train_dataloader:
|
||||
outputs = model(**batch)
|
||||
loss = outputs.loss
|
||||
accelerator.backward(loss)
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
lr_scheduler.step()
|
||||
|
||||
trained_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator)
|
||||
|
||||
assert (
|
||||
trained_model_results["accuracy"] > base_model_results["accuracy"]
|
||||
), f'Accuracy should be higher for the trained model: {trained_model_results["accuracy"]} > {base_model_results["accuracy"]}'
|
||||
assert (
|
||||
trained_model_results["f1"] > base_model_results["f1"]
|
||||
), f'F1 score should be higher for the trained model: {trained_model_results["f1"]} > {base_model_results["f1"]}'
|
||||
|
||||
return base_model_results, trained_model_results
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
baseline_not_trained, baseline_trained = train_baseline()
|
||||
accelerator_not_trained, accelerator_trained = train_integration()
|
||||
|
||||
assert (
|
||||
baseline_not_trained["accuracy"] == accelerator_not_trained["accuracy"]
|
||||
), f'Accuracy should be the same for the baseline and accelerator: {baseline_not_trained["accuracy"]} == {accelerator_not_trained["accuracy"]}'
|
||||
assert (
|
||||
baseline_not_trained["f1"] == accelerator_not_trained["f1"]
|
||||
), f'F1 score should be the same for the baseline and accelerator: {baseline_not_trained["f1"]} == {accelerator_not_trained["f1"]}'
|
||||
assert (
|
||||
baseline_trained["accuracy"] == accelerator_trained["accuracy"]
|
||||
), f'Accuracy should be the same for the baseline and accelerator: {baseline_trained["accuracy"]} == {accelerator_trained["accuracy"]}'
|
||||
assert (
|
||||
baseline_trained["f1"] == accelerator_trained["f1"]
|
||||
), f'F1 score should be the same for the baseline and accelerator: {baseline_trained["f1"]} == {accelerator_trained["f1"]}'
|
||||
|
||||
torch.distributed.destroy_process_group()
|
131
benchmarks/fp8/non_distributed.py
Normal file
131
benchmarks/fp8/non_distributed.py
Normal file
@ -0,0 +1,131 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
This script tests to ensure that `accelerate` performs at the same level as raw `TransformersEngine`.
|
||||
|
||||
This particular script verifies this for single GPU training.
|
||||
"""
|
||||
import evaluate
|
||||
import torch
|
||||
import transformer_engine.common.recipe as te_recipe
|
||||
import transformer_engine.pytorch as te
|
||||
from fp8_utils import evaluate_model, get_named_parameters, get_training_utilities
|
||||
from transformer_engine.common.recipe import DelayedScaling
|
||||
|
||||
from accelerate import Accelerator
|
||||
from accelerate.state import AcceleratorState
|
||||
from accelerate.utils import FP8RecipeKwargs, set_seed
|
||||
from accelerate.utils.transformer_engine import convert_model
|
||||
|
||||
|
||||
MODEL_NAME = "bert-base-cased"
|
||||
METRIC = evaluate.load("glue", "mrpc")
|
||||
|
||||
|
||||
def train_baseline():
|
||||
set_seed(42)
|
||||
model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = get_training_utilities(MODEL_NAME)
|
||||
|
||||
# Convert the model to TE
|
||||
old_named_params = get_named_parameters(model)
|
||||
|
||||
with torch.no_grad():
|
||||
convert_model(model)
|
||||
|
||||
new_named_params = get_named_parameters(model)
|
||||
mapping = {p: new_named_params[n] for n, p in old_named_params.items()}
|
||||
for param_group in optimizer.param_groups:
|
||||
param_group["params"] = [mapping[p] for p in param_group["params"]]
|
||||
|
||||
FP8_RECIPE_KWARGS = {"fp8_format": te_recipe.Format.HYBRID, "amax_history_len": 32, "amax_compute_algo": "max"}
|
||||
fp8_recipe = DelayedScaling(**FP8_RECIPE_KWARGS)
|
||||
|
||||
model.to("cuda")
|
||||
base_model_results = evaluate_model(model, eval_dataloader, METRIC)
|
||||
model.train()
|
||||
|
||||
for batch in train_dataloader:
|
||||
with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
|
||||
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
|
||||
batch = batch.to("cuda")
|
||||
outputs = model(**batch)
|
||||
loss = outputs.loss
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
lr_scheduler.step()
|
||||
|
||||
trained_model_results = evaluate_model(model, eval_dataloader, METRIC)
|
||||
|
||||
assert (
|
||||
trained_model_results["accuracy"] > base_model_results["accuracy"]
|
||||
), f'Accuracy should be higher for the trained model: {trained_model_results["accuracy"]} > {base_model_results["accuracy"]}'
|
||||
assert (
|
||||
trained_model_results["f1"] > base_model_results["f1"]
|
||||
), f'F1 score should be higher for the trained model: {trained_model_results["f1"]} > {base_model_results["f1"]}'
|
||||
|
||||
return base_model_results, trained_model_results
|
||||
|
||||
|
||||
def train_integration():
|
||||
FP8_RECIPE_KWARGS = {"fp8_format": "HYBRID", "amax_history_len": 32, "amax_compute_algo": "max"}
|
||||
kwargs_handlers = [FP8RecipeKwargs(backend="TE", **FP8_RECIPE_KWARGS)]
|
||||
AcceleratorState()._reset_state(True)
|
||||
accelerator = Accelerator(mixed_precision="fp8", kwargs_handlers=kwargs_handlers)
|
||||
set_seed(42)
|
||||
model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = get_training_utilities(
|
||||
MODEL_NAME, accelerator=accelerator
|
||||
)
|
||||
|
||||
model, optimizer, lr_scheduler = accelerator.prepare(model, optimizer, lr_scheduler)
|
||||
base_model_results = evaluate_model(model, eval_dataloader, METRIC)
|
||||
model.train()
|
||||
|
||||
for batch in train_dataloader:
|
||||
outputs = model(**batch)
|
||||
loss = outputs.loss
|
||||
accelerator.backward(loss)
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
lr_scheduler.step()
|
||||
|
||||
trained_model_results = evaluate_model(model, eval_dataloader, METRIC)
|
||||
|
||||
assert (
|
||||
trained_model_results["accuracy"] > base_model_results["accuracy"]
|
||||
), f'Accuracy should be higher for the trained model: {trained_model_results["accuracy"]} > {base_model_results["accuracy"]}'
|
||||
assert (
|
||||
trained_model_results["f1"] > base_model_results["f1"]
|
||||
), f'F1 score should be higher for the trained model: {trained_model_results["f1"]} > {base_model_results["f1"]}'
|
||||
|
||||
return base_model_results, trained_model_results
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
baseline_not_trained, baseline_trained = train_baseline()
|
||||
accelerator_not_trained, accelerator_trained = train_integration()
|
||||
|
||||
assert (
|
||||
baseline_not_trained["accuracy"] == accelerator_not_trained["accuracy"]
|
||||
), f'Accuracy should be the same for the baseline and accelerator: {baseline_not_trained["accuracy"]} == {accelerator_not_trained["accuracy"]}'
|
||||
assert (
|
||||
baseline_not_trained["f1"] == accelerator_not_trained["f1"]
|
||||
), f'F1 score should be the same for the baseline and accelerator: {baseline_not_trained["f1"]} == {accelerator_not_trained["f1"]}'
|
||||
assert (
|
||||
baseline_trained["accuracy"] == accelerator_trained["accuracy"]
|
||||
), f'Accuracy should be the same for the baseline and accelerator: {baseline_trained["accuracy"]} == {accelerator_trained["accuracy"]}'
|
||||
assert (
|
||||
baseline_trained["f1"] == accelerator_trained["f1"]
|
||||
), f'F1 score should be the same for the baseline and accelerator: {baseline_trained["f1"]} == {accelerator_trained["f1"]}'
|
@ -112,6 +112,8 @@
|
||||
title: Distributed inference with big models
|
||||
- local: package_reference/kwargs
|
||||
title: Kwargs handlers
|
||||
- local: package_reference/fp8
|
||||
title: FP8 Functionality
|
||||
- local: package_reference/utilities
|
||||
title: Utility functions and classes
|
||||
- local: package_reference/megatron_lm
|
||||
|
@ -145,10 +145,11 @@ values. They can also be passed in manually.
|
||||
|
||||
The following arguments are useful for fine-tuning how available hardware should be used
|
||||
|
||||
* `--mixed_precision {no,fp16,bf16}` (`str`) -- Whether or not to use mixed precision training. Choose between FP16 and BF16 (bfloat16) training. BF16 training is only supported on Nvidia Ampere GPUs and PyTorch 1.10 or later.
|
||||
* `--mixed_precision {no,fp16,bf16,fp8}` (`str`) -- Whether or not to use mixed precision training. Choose between FP16 and BF16 (bfloat16) training. BF16 training is only supported on Nvidia Ampere GPUs and PyTorch 1.10 or later.
|
||||
* `--num_processes NUM_PROCESSES` (`int`) -- The total number of processes to be launched in parallel.
|
||||
* `--num_machines NUM_MACHINES` (`int`) -- The total number of machines used in this training.
|
||||
* `--num_cpu_threads_per_process NUM_CPU_THREADS_PER_PROCESS` (`int`) -- The number of CPU threads per process. Can be tuned for optimal performance.
|
||||
* `--enable_cpu_affinity` (`bool`) -- Whether or not CPU affinity and balancing should be enabled. Currently only supported on NVIDIA hardware.
|
||||
|
||||
**Training Paradigm Arguments**:
|
||||
|
||||
@ -165,19 +166,26 @@ The following arguments are only useful when `multi_gpu` is passed or multi-gpu
|
||||
|
||||
* `--gpu_ids` (`str`) -- What GPUs (by id) should be used for training on this machine as a comma-seperated list
|
||||
* `--same_network` (`bool`) -- Whether all machines used for multinode training exist on the same local network.
|
||||
* `--machine_rank MACHINE_RANK` (`int`) -- The rank of the machine on which this script is launched.
|
||||
* `--main_process_ip MAIN_PROCESS_IP` (`str`) -- The IP address of the machine of rank 0.
|
||||
* `--main_process_port MAIN_PROCESS_PORT` (`int`) -- The port to use to communicate with the machine of rank 0.
|
||||
* `--rdzv_backend` (`str`) -- The rendezvous method to use, such as "static" or "c10d"
|
||||
* `--machine_rank` (`int`) -- The rank of the machine on which this script is launched.
|
||||
* `--main_process_ip` (`str`) -- The IP address of the machine of rank 0.
|
||||
* `--main_process_port` (`int`) -- The port to use to communicate with the machine of rank 0.
|
||||
* `-t`, `--tee` (`str`) -- Tee std streams into a log file and also to console.
|
||||
* `--log_dir` (`str`) -- Base directory to use for log files when using torchrun/torch.distributed.run as launcher. Use with --tee to redirect std streams info log files.
|
||||
* `--role` (`str`) -- User-defined role for the workers.
|
||||
* `--rdzv_backend` (`str`) -- The rendezvous method to use, such as 'static' (the default) or 'c10d'
|
||||
* `--rdzv_conf` (`str`) -- Additional rendezvous configuration (<key1>=<value1>,<key2>=<value2>,...).
|
||||
* `--max_restarts` (`int`) -- Maximum number of worker group restarts before failing.
|
||||
* `--monitor_interval` (`float`) -- Interval, in seconds, to monitor the state of workers.
|
||||
* `--monitor_interval` (`int`) -- Interval, in seconds, to monitor the state of workers.
|
||||
|
||||
**TPU Arguments**:
|
||||
|
||||
The following arguments are only useful when `tpu` is passed or TPU training is configured through `accelerate config`:
|
||||
|
||||
* `--main_training_function MAIN_TRAINING_FUNCTION` (`str`) -- The name of the main function to be executed in your script.
|
||||
* `--tpu_cluster` (`bool`) -- Whether to use a GCP TPU pod for training.
|
||||
* `--tpu_use_sudo` (`bool`) -- Whether to use `sudo` when running the TPU training script in each pod.
|
||||
* `--vm` (`str`) -- List of single Compute VM instance names. If not provided we assume usage of instance groups. For TPU pods.
|
||||
* `--env` (`str`) -- List of environment variables to set on the Compute VM instances. For TPU pods.
|
||||
* `--main_training_function` (`str`) -- The name of the main function to be executed in your script (only for TPU training).
|
||||
* `--downcast_bf16` (`bool`) -- Whether when using bf16 precision on TPUs if both float and double tensors are cast to bfloat16 or if double tensors remain as float32.
|
||||
|
||||
**DeepSpeed Arguments**:
|
||||
@ -188,6 +196,7 @@ The following arguments are only useful when `use_deepspeed` is passed or `deeps
|
||||
* `--zero_stage` (`int`) -- DeepSpeed's ZeRO optimization stage.
|
||||
* `--offload_optimizer_device` (`str`) -- Decides where (none|cpu|nvme) to offload optimizer states.
|
||||
* `--offload_param_device` (`str`) -- Decides where (none|cpu|nvme) to offload parameters.
|
||||
* `--offload_optimizer_nvme_path` (`str`) -- Decides Nvme Path to offload optimizer states.
|
||||
* `--gradient_accumulation_steps` (`int`) -- No of gradient_accumulation_steps used in your training script.
|
||||
* `--gradient_clipping` (`float`) -- Gradient clipping value used in your training script.
|
||||
* `--zero3_init_flag` (`str`) -- Decides Whether (true|false) to enable `deepspeed.zero.Init` for constructing massive models. Only applicable with DeepSpeed ZeRO Stage-3.
|
||||
@ -196,6 +205,7 @@ The following arguments are only useful when `use_deepspeed` is passed or `deeps
|
||||
* `--deepspeed_exclusion_filter` (`str`) -- DeepSpeed exclusion filter string when using mutli-node setup.
|
||||
* `--deepspeed_inclusion_filter` (`str`) -- DeepSpeed inclusion filter string when using mutli-node setup.
|
||||
* `--deepspeed_multinode_launcher` (`str`) -- DeepSpeed multi-node launcher to use.
|
||||
* `--deepspeed_moe_layer_cls_names` (`str`) -- comma-separated list of transformer MoE layer class names (case-sensitive) to wrap, e.g, `MixtralSparseMoeBlock` `Qwen2MoeSparseMoeBlock`, `JetMoEAttention,JetMoEBlock`
|
||||
|
||||
**Fully Sharded Data Parallelism Arguments**:
|
||||
|
||||
@ -210,8 +220,9 @@ The following arguments are only useful when `use_fsdp` is passed or Fully Shard
|
||||
* `--fsdp_state_dict_type` (`str`) -- FSDP's state dict type.
|
||||
* `--fsdp_forward_prefetch` (`str`) -- FSDP forward prefetch.
|
||||
* `--fsdp_use_orig_params` (`str`) -- If True, allows non-uniform `requires_grad` mixed in a FSDP unit.
|
||||
* `--fsdp_cpu_ram_efficient_loading` (`str`) - If true, only the first process loads the pretrained model checkoint while all other processes have empty weights. When using this, `--fsdp_sync_module_states` needs to True.
|
||||
* `--fsdp_sync_module_states` (`str`) - If true, each individually wrapped FSDP unit will broadcast module parameters from rank 0.
|
||||
* `--fsdp_cpu_ram_efficient_loading` (`str`) -- If true, only the first process loads the pretrained model checkoint while all other processes have empty weights. When using this, `--fsdp_sync_module_states` needs to True.
|
||||
* `--fsdp_sync_module_states` (`str`) -- If true, each individually wrapped FSDP unit will broadcast module parameters from rank 0.
|
||||
* `--fsdp_activation_checkpointing` (`bool`) -- Decides Whether intermediate activations are freed during the forward pass, and a checkpoint is left as a placeholder
|
||||
|
||||
**Megatron-LM Arguments**:
|
||||
|
||||
@ -225,6 +236,18 @@ The following arguments are only useful when `use_megatron_lm` is passed or Mega
|
||||
* `--megatron_lm_use_distributed_optimizer` (``) -- Decides Whether (true|false) to use distributed optimizer which shards optimizer state and gradients across Data Parallel (DP) ranks.
|
||||
* `--megatron_lm_gradient_clipping` (``) -- Megatron-LM's gradient clipping value based on global L2 Norm (0 to disable).
|
||||
|
||||
**FP8 Arguments**:
|
||||
|
||||
* `--fp8_backend` (`str`) -- Choose a backend to train with FP8 (`te` or `msamp`)
|
||||
* `--fp8_use_autocast_during_eval` (`bool`) -- Whether to use FP8 autocast during eval mode (useful only when `--fp8_backend=te` is passed). Generally better metrics are found when this is not passed.
|
||||
* `--fp8_margin` (`int`) -- The margin to use for the gradient scaling (useful only when `--fp8_backend=te` is passed).
|
||||
* `--fp8_interval` (`int`) -- The interval to use for how often the scaling factor is recomputed (useful only when `--fp8_backend=te` is passed).
|
||||
* `--fp8_format` (`str`) -- The format to use for the FP8 recipe (useful only when `--fp8_backend=te` is passed).
|
||||
* `--fp8_amax_history_len` (`int`) -- The length of the history to use for the scaling factor computation (useful only when `--fp8_backend=te` is passed).
|
||||
* `--fp8_amax_compute_algo` (`str`) -- The algorithm to use for the scaling factor computation. (useful only when `--fp8_backend=te` is passed).
|
||||
* `--fp8_override_linear_precision` (`Tuple[bool, bool, bool]`) -- Whether or not to execute `fprop`, `dgrad`, and `wgrad` GEMMS in higher precision.
|
||||
* `--fp8_opt_level` (`str`) -- What level of 8-bit collective communication should be used with MS-AMP (useful only when `--fp8_backend=msamp` is passed)
|
||||
|
||||
**AWS SageMaker Arguments**:
|
||||
|
||||
The following arguments are only useful when training in SageMaker
|
||||
|
28
docs/source/package_reference/fp8.md
Normal file
28
docs/source/package_reference/fp8.md
Normal file
@ -0,0 +1,28 @@
|
||||
<!--Copyright 2021 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
|
||||
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
|
||||
rendered properly in your Markdown viewer.
|
||||
-->
|
||||
|
||||
# FP8 Functionality
|
||||
|
||||
Below are functions and classes relative to the underlying FP8 implementation
|
||||
|
||||
[[autodoc]] utils.FP8RecipeKwargs
|
||||
|
||||
[[autodoc]] utils.convert_model
|
||||
|
||||
[[autodoc]] utils.has_transformer_engine_layers
|
||||
|
||||
[[autodoc]] utils.contextual_fp8_autocast
|
||||
|
||||
[[autodoc]] utils.apply_fp8_autowrap
|
@ -39,7 +39,7 @@ from accelerate import Accelerator
|
||||
accelerator = Accelerator(mixed_precision="fp8")
|
||||
```
|
||||
|
||||
By default, if `MS-AMP` is available in your environment, 🤗 Accelerate will automatically utilize it as a backend. To specify it yourself (and customize other parts of the FP8 mixed precision setup), you can utilize the [`utils.FP8RecipeKwargs`]:
|
||||
By default, if `MS-AMP` is available in your environment, 🤗 Accelerate will automatically utilize it as a backend. To specify it yourself (and customize other parts of the FP8 mixed precision setup), you can utilize the [`utils.FP8RecipeKwargs`] or clarify it in your config `yaml`/during `accelerate launch`:
|
||||
|
||||
```{python}
|
||||
from accelerate import Accelerator
|
||||
@ -50,6 +50,19 @@ kwargs = [FP8RecipeKwargs(backend="msamp")]
|
||||
accelerator = Accelerator(mixed_precision="fp8", kwarg_handlers=kwargs)
|
||||
```
|
||||
|
||||
```{yaml}
|
||||
mixed_precision: fp8
|
||||
fp8_config:
|
||||
amax_compute_algorithm: max
|
||||
amax_history_length: 1024
|
||||
backend: TE
|
||||
fp8_format: E4M3
|
||||
interval: 1
|
||||
margin: 0
|
||||
override_linear_precision: false
|
||||
use_autocast_during_eval: false
|
||||
```
|
||||
|
||||
## Configuring MS-AMP
|
||||
|
||||
Of the two, `MS-AMP` is traditionally the easier one to configure as there is only a single argument: the optimization level.
|
||||
@ -68,6 +81,17 @@ kwargs = [FP8RecipeKwargs(backend="msamp", optimization_level="O2")]
|
||||
accelerator = Accelerator(mixed_precision="fp8", kwarg_handlers=kwargs)
|
||||
```
|
||||
|
||||
Or during `accelerate launch` via `--fp8_backend=msamp --fp8_opt_level=O2`
|
||||
|
||||
Similarly this can be set in your `config.yaml`:
|
||||
|
||||
```{yaml}
|
||||
mixed_precision: fp8
|
||||
fp8_config:
|
||||
backend: MSAMP
|
||||
opt_level: O2
|
||||
```
|
||||
|
||||
## Configuring TransformersEngine
|
||||
|
||||
TransformersEngine has much more available for customizing how and what FP8 calculations are performed. A full list of supported arguments and what they mean are available in [NVIDIA's documentation](https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/api/common.html), however they are restated as part of [`FP8KwargsHandler`]'s docstring for your convenience.
|
||||
@ -83,6 +107,35 @@ kwargs = [FP8RecipeKwargs(backend="te", ...)]
|
||||
accelerator = Accelerator(mixed_precision="fp8", kwarg_handlers=kwargs)
|
||||
```
|
||||
|
||||
Or during `accelerate launch` via `--fp8_backend=te ...`. Use `accelerate launch --fp8_backend=te -h` to see relevent arguments.
|
||||
|
||||
Similarly this can be set in your `config.yaml`:
|
||||
|
||||
```{yaml}
|
||||
mixed_precision: fp8
|
||||
fp8_config:
|
||||
amax_compute_algorithm: max
|
||||
amax_history_length: 1024
|
||||
backend: TE
|
||||
fp8_format: E4M3
|
||||
interval: 1
|
||||
margin: 0
|
||||
override_linear_precision: false
|
||||
use_autocast_during_eval: false
|
||||
```
|
||||
|
||||
## Example Zoo
|
||||
|
||||
We have examples showcasing training with FP8 both with accelerate and its underlying implementation available in the accelerate repo.
|
||||
Currently we support scripts showcasing:
|
||||
|
||||
* Single GPU
|
||||
* Distributed Data Parallelism (Multi-GPU)
|
||||
* Fully Sharded Data Parallelism
|
||||
* DeepSpeed ZeRO 1 through 3
|
||||
|
||||
Find out more [here](https://github.com/huggingface/accelerate/tree/main/benchmarks/fp8)
|
||||
|
||||
## Further Reading
|
||||
|
||||
To learn more about training in FP8 please check out the following resources:
|
||||
|
@ -68,6 +68,7 @@ from .utils import (
|
||||
ProjectConfiguration,
|
||||
RNGType,
|
||||
TorchDynamoPlugin,
|
||||
apply_fp8_autowrap,
|
||||
check_os_kernel,
|
||||
clean_state_dict_for_safetensors,
|
||||
compare_versions,
|
||||
@ -78,7 +79,6 @@ from .utils import (
|
||||
gather_object,
|
||||
get_mixed_precision_context_manager,
|
||||
get_pretty_name,
|
||||
has_transformer_engine_layers,
|
||||
is_bf16_available,
|
||||
is_deepspeed_available,
|
||||
is_ipex_available,
|
||||
@ -391,10 +391,15 @@ class Accelerator:
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
if self.state.mixed_precision == "fp8" and self.fp8_recipe_handler is None:
|
||||
self.fp8_recipe_handler = FP8RecipeKwargs()
|
||||
|
||||
self.delayed_fp8_autocast = False
|
||||
if self.fp8_recipe_handler is not None:
|
||||
# We already check if FP8 is available during `self.state`
|
||||
if self.state.mixed_precision != "fp8":
|
||||
if self.state.mixed_precision != "fp8" and (
|
||||
self.distributed_type not in (DistributedType.FSDP, DistributedType.DEEPSPEED)
|
||||
):
|
||||
raise ValueError("Passing in a `FP8RecipeKwargs` object requires setting `mixed_precision='fp8'`.")
|
||||
self.delayed_fp8_autocast = self.fp8_recipe_handler.backend == "TE" and self.distributed_type in (
|
||||
DistributedType.MULTI_GPU,
|
||||
@ -1290,7 +1295,8 @@ class Accelerator:
|
||||
|
||||
# If we're dealing with device placement, this deals with that by...
|
||||
tpu_should_fix_optimizer = self.device_placement and self.distributed_type == DistributedType.XLA
|
||||
if tpu_should_fix_optimizer or (self.mixed_precision == "fp8" and self.fp8_recipe_handler.backend == "TE"):
|
||||
|
||||
if tpu_should_fix_optimizer:
|
||||
# 1. grabbing old model parameters
|
||||
old_named_params = self._get_named_parameters(*args)
|
||||
|
||||
@ -1299,6 +1305,8 @@ class Accelerator:
|
||||
args = self._prepare_ipex_or_xpu(*args)
|
||||
elif self.device.type == "xpu" and is_xpu_available():
|
||||
args = self._prepare_ipex_or_xpu(*args)
|
||||
if self.fp8_recipe_handler is not None and self.fp8_recipe_handler.backend == "TE":
|
||||
args = self._prepare_te(*args)
|
||||
if self.distributed_type == DistributedType.DEEPSPEED:
|
||||
result = self._prepare_deepspeed(*args)
|
||||
elif self.distributed_type == DistributedType.MEGATRON_LM:
|
||||
@ -1312,8 +1320,7 @@ class Accelerator:
|
||||
self._prepare_one(obj, first_pass=True, device_placement=d) for obj, d in zip(args, device_placement)
|
||||
)
|
||||
result = tuple(self._prepare_one(obj, device_placement=d) for obj, d in zip(result, device_placement))
|
||||
|
||||
if tpu_should_fix_optimizer or (self.mixed_precision == "fp8" and self.fp8_recipe_handler.backend == "TE"):
|
||||
if tpu_should_fix_optimizer:
|
||||
# 2. grabbing new model parameters
|
||||
new_named_params = self._get_named_parameters(*result)
|
||||
# 3. building a map from the first to the second
|
||||
@ -1384,24 +1391,8 @@ class Accelerator:
|
||||
model.forward = convert_outputs_to_fp32(new_forward)
|
||||
|
||||
# We prepare fp8 after, allowing for bf16 autocast to happen first
|
||||
if getattr(self.fp8_recipe_handler, "backend", None) == "TE":
|
||||
# Import here to keep base imports fast
|
||||
import transformer_engine.common.recipe as te_recipe
|
||||
from transformer_engine.pytorch import fp8_autocast
|
||||
|
||||
if not has_transformer_engine_layers(model):
|
||||
with torch.no_grad():
|
||||
convert_model(model)
|
||||
model._converted_to_transformer_engine = True
|
||||
|
||||
kwargs = self.fp8_recipe_handler.to_kwargs() if self.fp8_recipe_handler is not None else {}
|
||||
if "fp8_format" in kwargs:
|
||||
kwargs["fp8_format"] = getattr(te_recipe.Format, kwargs["fp8_format"])
|
||||
fp8_recipe = te_recipe.DelayedScaling(**kwargs)
|
||||
# If we are in DDP or FSDP, we delay `autocast` until after FSDP/DDP has been initialized
|
||||
# to make use of the process group
|
||||
if not self.delayed_fp8_autocast:
|
||||
model.forward = fp8_autocast(enabled=True, fp8_recipe=fp8_recipe)(model.forward)
|
||||
if getattr(self.fp8_recipe_handler, "backend", None) == "TE" and not self.delayed_fp8_autocast:
|
||||
model = apply_fp8_autowrap(model, self.fp8_recipe_handler)
|
||||
|
||||
if (getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False)) and getattr(
|
||||
model, "hf_device_map", False
|
||||
@ -1455,6 +1446,7 @@ class Accelerator:
|
||||
if self.ddp_handler is not None:
|
||||
self.ddp_handler.register_comm_hook(model)
|
||||
elif self.distributed_type == DistributedType.FSDP:
|
||||
# We need to fix the optimizer *before* sharding the model
|
||||
from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP
|
||||
|
||||
# Check if the model is already a FSDP model due to `Manual Wrapping` and if so,
|
||||
@ -1578,9 +1570,7 @@ class Accelerator:
|
||||
model = xmp.MpModelWrapper(model).to(self.device)
|
||||
# Now we can apply the FP8 autocast
|
||||
if self.delayed_fp8_autocast:
|
||||
model.forward = fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, fp8_group=model.process_group)(
|
||||
model.forward
|
||||
)
|
||||
model = apply_fp8_autowrap(model, self.fp8_recipe_handler)
|
||||
# torch.compile should be called last and only if the model isn't already compiled.
|
||||
if self.state.dynamo_plugin.backend != DynamoBackend.NO and not is_compiled_module(model):
|
||||
if not is_torch_version(">=", "2.0"):
|
||||
@ -1588,6 +1578,42 @@ class Accelerator:
|
||||
model = torch.compile(model, **self.state.dynamo_plugin.to_kwargs())
|
||||
return model
|
||||
|
||||
def _prepare_te(self, *args):
|
||||
if not is_transformer_engine_available():
|
||||
raise ImportError(
|
||||
"`transformer_engine` was not found on your system. Please ensure that `transformer_engine` is installed"
|
||||
)
|
||||
model, optimizer = None, None
|
||||
num_models, num_optimizers = 0, 0
|
||||
result = [obj for obj in args]
|
||||
for obj in result:
|
||||
if isinstance(obj, torch.nn.Module):
|
||||
model = obj
|
||||
num_models += 1
|
||||
elif isinstance(obj, (torch.optim.Optimizer)):
|
||||
optimizer = obj
|
||||
num_optimizers += 1
|
||||
if optimizer is None and model is None:
|
||||
return result
|
||||
elif optimizer is None or model is None:
|
||||
raise ValueError(
|
||||
"You must pass a model and an optimizer together to `accelerate.prepare()` when using TransformerEngine."
|
||||
)
|
||||
elif num_models > 1 or num_optimizers > 1:
|
||||
raise ValueError(
|
||||
f"You can't use multiple models ({num_models}) or optimizers {num_optimizers} with TransformerEngine."
|
||||
)
|
||||
old_named_params = self._get_named_parameters(model)
|
||||
with torch.no_grad():
|
||||
convert_model(model)
|
||||
new_named_params = self._get_named_parameters(model)
|
||||
mapping = {p: new_named_params[n] for n, p in old_named_params.items()}
|
||||
# We need to switch the optimizer params to the new params *after* the model is wrapped in FSDP
|
||||
for param_group in optimizer.param_groups:
|
||||
param_group["params"] = [mapping[p] for p in param_group["params"]]
|
||||
|
||||
return result
|
||||
|
||||
def _prepare_deepspeed(self, *args):
|
||||
import deepspeed
|
||||
|
||||
@ -1696,6 +1722,9 @@ class Accelerator:
|
||||
)
|
||||
|
||||
if model is not None:
|
||||
# If we are using FP8, we need to apply the autowrap now
|
||||
if getattr(self.fp8_recipe_handler, "backend", None) == "TE":
|
||||
model = apply_fp8_autowrap(model, self.fp8_recipe_handler)
|
||||
# if the model is an MOE, set the appropriate MOE layers as leaf Z3 modules
|
||||
deepspeed_plugin.set_moe_leaf_modules(model)
|
||||
# deal with config keys that use `auto` value and rely on model's hidden_size
|
||||
|
@ -20,10 +20,13 @@ from ...utils import (
|
||||
ComputeEnvironment,
|
||||
DistributedType,
|
||||
is_deepspeed_available,
|
||||
is_fp8_available,
|
||||
is_mlu_available,
|
||||
is_mps_available,
|
||||
is_msamp_available,
|
||||
is_musa_available,
|
||||
is_npu_available,
|
||||
is_transformer_engine_available,
|
||||
is_transformers_available,
|
||||
is_xpu_available,
|
||||
)
|
||||
@ -42,6 +45,7 @@ from .config_utils import (
|
||||
_ask_options,
|
||||
_convert_distributed_mode,
|
||||
_convert_dynamo_backend,
|
||||
_convert_fp8_backend,
|
||||
_convert_mixed_precision,
|
||||
_convert_yes_no_to_bool,
|
||||
)
|
||||
@ -616,6 +620,7 @@ def get_cluster_input():
|
||||
error_message="Please enter yes or no.",
|
||||
)
|
||||
|
||||
fp8_config = None
|
||||
if distributed_type == DistributedType.XLA:
|
||||
mixed_precision = "no"
|
||||
main_training_function = _ask_field(
|
||||
@ -697,10 +702,86 @@ def get_cluster_input():
|
||||
mixed_precision = None
|
||||
else:
|
||||
mixed_precision = _ask_options(
|
||||
"Do you wish to use FP16 or BF16 (mixed precision)?",
|
||||
"Do you wish to use mixed precision?",
|
||||
["no", "fp16", "bf16", "fp8"],
|
||||
_convert_mixed_precision,
|
||||
)
|
||||
if mixed_precision == "fp8":
|
||||
if not is_fp8_available():
|
||||
raise ValueError("FP8 (either Transformer Engine or MSAMP) is not installed on this machine.")
|
||||
fp8_config = {}
|
||||
fp8_config["backend"] = _ask_options(
|
||||
"Which FP8 backend do you want to use?",
|
||||
["te", "msamp"],
|
||||
_convert_fp8_backend,
|
||||
)
|
||||
if fp8_config["backend"] == "TE":
|
||||
if not is_transformer_engine_available():
|
||||
raise ValueError("TransformersEngine was selected, but it is not installed on this machine.")
|
||||
fp8_config["use_autocast_during_eval"] = _ask_field(
|
||||
"Do you want to use FP8 autocast during eval mode? Generally better metrics are found when this is disabled [yes/NO]: ",
|
||||
_convert_yes_no_to_bool,
|
||||
default=False,
|
||||
)
|
||||
fp8_config["margin"] = _ask_field(
|
||||
"What margin should be used for gradient scaling? [0]: ",
|
||||
int,
|
||||
default=0,
|
||||
)
|
||||
fp8_config["interval"] = _ask_field(
|
||||
"What interval should be used for for how often the scaling factor is recomputed? [1]: ",
|
||||
int,
|
||||
default=1,
|
||||
)
|
||||
fp8_config["fp8_format"] = _ask_options(
|
||||
"Which weight format should be used?",
|
||||
["E4M3", "HYBRID"],
|
||||
lambda x: "E4M3" if x == 0 else "HYBRID",
|
||||
default=0,
|
||||
)
|
||||
fp8_config["amax_history_length"] = _ask_field(
|
||||
"What length of history should be used for the amax scaling factor computation? [1024]: ",
|
||||
int,
|
||||
default=1024,
|
||||
)
|
||||
fp8_config["amax_compute_algorithm"] = _ask_options(
|
||||
"Which algorithm should be used for the amax scaling factor computation?",
|
||||
["max", "most_recent"],
|
||||
lambda x: "max" if x == 0 else "most_recent",
|
||||
default=0,
|
||||
)
|
||||
fp8_config["override_linear_precision"] = _ask_field(
|
||||
"Do you want to to execute `fprop`, `dgrad`, and `wgrad` GEMMS in higher precision? [yes/NO]: ",
|
||||
_convert_yes_no_to_bool,
|
||||
default=False,
|
||||
)
|
||||
if fp8_config["override_linear_precision"]:
|
||||
fprop = _ask_field(
|
||||
"Should `fprop` be executed in higher precision? [yes/NO]: ",
|
||||
_convert_yes_no_to_bool,
|
||||
default=False,
|
||||
)
|
||||
dgrad = _ask_field(
|
||||
"Should `dgrad` be executed in higher precision? [yes/NO]: ",
|
||||
_convert_yes_no_to_bool,
|
||||
default=False,
|
||||
)
|
||||
wgrad = _ask_field(
|
||||
"Should `wgrad` be executed in higher precision? [yes/NO]: ",
|
||||
_convert_yes_no_to_bool,
|
||||
default=False,
|
||||
)
|
||||
fp8_config["override_linear_precision"] = (fprop, dgrad, wgrad)
|
||||
|
||||
elif fp8_config["backend"] == "MSAMP":
|
||||
if not is_msamp_available():
|
||||
raise ValueError("MSAMP was selected, but it is not installed on this machine.")
|
||||
fp8_config["optimization_level"] = _ask_options(
|
||||
"Which optimization level should be used?",
|
||||
["O1", "O2"],
|
||||
lambda x: "O1" if x == 0 else "O2",
|
||||
default=1,
|
||||
)
|
||||
|
||||
if use_dynamo and mixed_precision == "no" and not use_cpu:
|
||||
print(
|
||||
@ -724,6 +805,7 @@ def get_cluster_input():
|
||||
main_process_ip=main_process_ip,
|
||||
main_process_port=main_process_port,
|
||||
main_training_function=main_training_function,
|
||||
fp8_config=fp8_config,
|
||||
deepspeed_config=deepspeed_config,
|
||||
fsdp_config=fsdp_config,
|
||||
megatron_lm_config=megatron_lm_config,
|
||||
|
@ -83,11 +83,19 @@ class BaseConfig:
|
||||
def to_dict(self):
|
||||
result = self.__dict__
|
||||
# For serialization, it's best to convert Enums to strings (or their underlying value type).
|
||||
for key, value in result.items():
|
||||
|
||||
def _convert_enums(value):
|
||||
if isinstance(value, Enum):
|
||||
result[key] = value.value
|
||||
if isinstance(value, dict) and not bool(value):
|
||||
result[key] = None
|
||||
return value.value
|
||||
if isinstance(value, dict):
|
||||
if not bool(value):
|
||||
return None
|
||||
for key1, value1 in value.items():
|
||||
value[key1] = _convert_enums(value1)
|
||||
return value
|
||||
|
||||
for key, value in result.items():
|
||||
result[key] = _convert_enums(value)
|
||||
result = {k: v for k, v in result.items() if v is not None}
|
||||
return result
|
||||
|
||||
@ -184,6 +192,8 @@ class ClusterConfig(BaseConfig):
|
||||
main_training_function: str = "main"
|
||||
enable_cpu_affinity: bool = False
|
||||
|
||||
# args for FP8 training
|
||||
fp8_config: dict = None
|
||||
# args for deepspeed_plugin
|
||||
deepspeed_config: dict = None
|
||||
# args for fsdp
|
||||
@ -221,6 +231,8 @@ class ClusterConfig(BaseConfig):
|
||||
self.ipex_config = {}
|
||||
if self.mpirun_config is None:
|
||||
self.mpirun_config = {}
|
||||
if self.fp8_config is None:
|
||||
self.fp8_config = {}
|
||||
return super().__post_init__()
|
||||
|
||||
|
||||
|
@ -20,6 +20,7 @@ from ...utils.dataclasses import (
|
||||
ComputeEnvironment,
|
||||
DistributedType,
|
||||
DynamoBackend,
|
||||
FP8BackendType,
|
||||
PrecisionType,
|
||||
SageMakerDistributedType,
|
||||
)
|
||||
@ -90,6 +91,11 @@ def _convert_sagemaker_distributed_mode(value):
|
||||
return SageMakerDistributedType(["NO", "DATA_PARALLEL", "MODEL_PARALLEL"][value])
|
||||
|
||||
|
||||
def _convert_fp8_backend(value):
|
||||
value = int(value)
|
||||
return FP8BackendType(["TE", "MSAMP"][value])
|
||||
|
||||
|
||||
def _convert_yes_no_to_bool(value):
|
||||
return {"yes": True, "no": False}[value.lower()]
|
||||
|
||||
|
@ -53,6 +53,7 @@ from accelerate.utils import (
|
||||
prepare_sagemager_args_inputs,
|
||||
prepare_simple_launcher_cmd_env,
|
||||
prepare_tpu,
|
||||
str_to_bool,
|
||||
)
|
||||
from accelerate.utils.constants import DEEPSPEED_MULTINODE_LAUNCHERS, TORCH_DYNAMO_MODES
|
||||
|
||||
@ -74,11 +75,14 @@ options_to_group = {
|
||||
"use_deepspeed": "DeepSpeed Arguments",
|
||||
"use_fsdp": "FSDP Arguments",
|
||||
"use_megatron_lm": "Megatron-LM Arguments",
|
||||
"fp8_backend": "FP8 Arguments",
|
||||
}
|
||||
|
||||
|
||||
def clean_option(option):
|
||||
"Finds all cases of - after the first two characters and changes them to _"
|
||||
if "fp8_backend" in option:
|
||||
option = "--fp8_backend"
|
||||
if option.startswith("--"):
|
||||
return option[2:].replace("-", "_")
|
||||
|
||||
@ -214,7 +218,6 @@ def launch_command_parser(subparsers=None):
|
||||
action="store_true",
|
||||
help="Whether or not CPU affinity and balancing should be enabled. Currently only supported on NVIDIA hardware.",
|
||||
)
|
||||
|
||||
# Dynamo arguments
|
||||
resource_args.add_argument(
|
||||
"--dynamo_backend",
|
||||
@ -642,6 +645,68 @@ def launch_command_parser(subparsers=None):
|
||||
"(useful only when `use_megatron_lm` flag is passed).",
|
||||
)
|
||||
|
||||
# FP8 arguments
|
||||
fp8_args = parser.add_argument_group(
|
||||
"FP8 Arguments", "Arguments related to FP8 training (requires `--mixed_precision=fp8`)"
|
||||
)
|
||||
fp8_args.add_argument(
|
||||
"--fp8_backend",
|
||||
type=str,
|
||||
choices=["te", "msamp"],
|
||||
help="Choose a backend to train with FP8 (te: TransformerEngine, msamp: MS-AMP)",
|
||||
)
|
||||
fp8_args.add_argument(
|
||||
"--fp8_use_autocast_during_eval",
|
||||
default=False,
|
||||
action="store_true",
|
||||
help="Whether to use FP8 autocast during eval mode (useful only when `--fp8_backend=te` is passed). Generally better metrics are found when this is not passed.",
|
||||
)
|
||||
fp8_args.add_argument(
|
||||
"--fp8_margin",
|
||||
type=int,
|
||||
default=0,
|
||||
help="The margin to use for the gradient scaling (useful only when `--fp8_backend=te` is passed).",
|
||||
)
|
||||
fp8_args.add_argument(
|
||||
"--fp8_interval",
|
||||
type=int,
|
||||
default=1,
|
||||
help="The interval to use for how often the scaling factor is recomputed (useful only when `--fp8_backend=te` is passed).",
|
||||
)
|
||||
fp8_args.add_argument(
|
||||
"--fp8_format",
|
||||
type=str,
|
||||
default="E4M3",
|
||||
choices=["E4M3", "HYBRID"],
|
||||
help="The format to use for the FP8 recipe (useful only when `--fp8_backend=te` is passed).",
|
||||
)
|
||||
fp8_args.add_argument(
|
||||
"--fp8_amax_history_len",
|
||||
type=int,
|
||||
default=1024,
|
||||
help="The length of the history to use for the scaling factor computation (useful only when `--fp8_backend=te` is passed).",
|
||||
)
|
||||
fp8_args.add_argument(
|
||||
"--fp8_amax_compute_algo",
|
||||
type=str,
|
||||
default="most_recent",
|
||||
choices=["max", "most_recent"],
|
||||
help="The algorithm to use for the scaling factor computation. (useful only when `--fp8_backend=te` is passed).",
|
||||
)
|
||||
fp8_args.add_argument(
|
||||
"--fp8_override_linear_precision",
|
||||
type=lambda x: tuple(map(str_to_bool, x.split(","))),
|
||||
default=(False, False, False),
|
||||
help="Whether or not to execute `fprop`, `dgrad`, and `wgrad` GEMMS in higher precision. Should be passed in a comma-seperated string of booleans (useful only when `--fp8_backend=te` is passed).",
|
||||
)
|
||||
fp8_args.add_argument(
|
||||
"--fp8_opt_level",
|
||||
type=str,
|
||||
default="O2",
|
||||
choices=["O1", "O2"],
|
||||
help="What level of 8-bit collective communication should be used with MS-AMP (useful only when `--fp8_backend=msamp` is passed).",
|
||||
)
|
||||
|
||||
# AWS arguments
|
||||
aws_args = parser.add_argument_group("AWS Arguments", "Arguments related to AWS.")
|
||||
aws_args.add_argument(
|
||||
|
@ -41,6 +41,7 @@ from .testing import (
|
||||
require_torch_min_version,
|
||||
require_torchvision,
|
||||
require_tpu,
|
||||
require_transformer_engine,
|
||||
require_xpu,
|
||||
skip,
|
||||
slow,
|
||||
|
@ -53,6 +53,7 @@ from ..utils import (
|
||||
is_torch_version,
|
||||
is_torch_xla_available,
|
||||
is_torchvision_available,
|
||||
is_transformer_engine_available,
|
||||
is_transformers_available,
|
||||
is_triton_available,
|
||||
is_wandb_available,
|
||||
@ -404,6 +405,14 @@ def require_import_timer(test_case):
|
||||
return unittest.skipUnless(is_import_timer_available(), "test requires tuna interpreter")(test_case)
|
||||
|
||||
|
||||
def require_transformer_engine(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires transformers engine installed. These tests are skipped when transformers
|
||||
engine isn't installed
|
||||
"""
|
||||
return unittest.skipUnless(is_transformer_engine_available(), "test requires transformers engine")(test_case)
|
||||
|
||||
|
||||
_atleast_one_tracker_available = (
|
||||
any([is_wandb_available(), is_tensorboard_available()]) and not is_comet_ml_available()
|
||||
)
|
||||
|
@ -149,6 +149,7 @@ from .offload import (
|
||||
)
|
||||
from .operations import (
|
||||
CannotPadNestedTensorWarning,
|
||||
GatheredParameters,
|
||||
broadcast,
|
||||
broadcast_object_list,
|
||||
concatenate,
|
||||
@ -250,4 +251,9 @@ from .other import (
|
||||
from .random import set_seed, synchronize_rng_state, synchronize_rng_states
|
||||
from .torch_xla import install_xla
|
||||
from .tqdm import tqdm
|
||||
from .transformer_engine import convert_model, has_transformer_engine_layers
|
||||
from .transformer_engine import (
|
||||
apply_fp8_autowrap,
|
||||
contextual_fp8_autocast,
|
||||
convert_model,
|
||||
has_transformer_engine_layers,
|
||||
)
|
||||
|
@ -30,8 +30,15 @@ from typing import Any, Callable, Dict, Iterable, List, Literal, Optional, Tuple
|
||||
import torch
|
||||
|
||||
from .constants import FSDP_AUTO_WRAP_POLICY, FSDP_BACKWARD_PREFETCH, FSDP_SHARDING_STRATEGY
|
||||
from .environment import str_to_bool
|
||||
from .imports import is_cuda_available, is_mlu_available, is_npu_available, is_xpu_available
|
||||
from .environment import parse_flag_from_env, str_to_bool
|
||||
from .imports import (
|
||||
is_cuda_available,
|
||||
is_mlu_available,
|
||||
is_msamp_available,
|
||||
is_npu_available,
|
||||
is_transformer_engine_available,
|
||||
is_xpu_available,
|
||||
)
|
||||
from .versions import compare_versions
|
||||
|
||||
|
||||
@ -297,8 +304,11 @@ class FP8RecipeKwargs(KwargsHandler):
|
||||
```
|
||||
|
||||
Args:
|
||||
backend (`str`, *optional*, defaults to "msamp"):
|
||||
Which FP8 engine to use. Must be one of `"msamp"` (MS-AMP) or `"te"` (TransformerEngine).
|
||||
backend (`str`, *optional*):
|
||||
Which FP8 engine to use. Must be one of `"msamp"` (MS-AMP) or `"te"` (TransformerEngine). If not passed,
|
||||
will use whichever is available in the environment, prioritizing MS-AMP.
|
||||
use_autocast_during_eval (`bool`, *optional*, default to `False`):
|
||||
Whether to use FP8 autocast during eval mode. Generally better metrics are found when this is `False`.
|
||||
margin (`int`, *optional*, default to 0):
|
||||
The margin to use for the gradient scaling.
|
||||
interval (`int`, *optional*, default to 1):
|
||||
@ -323,28 +333,60 @@ class FP8RecipeKwargs(KwargsHandler):
|
||||
available currently).
|
||||
"""
|
||||
|
||||
backend: Backend = "MSAMP"
|
||||
opt_level: OptLevel = "O2"
|
||||
margin: int = 0
|
||||
interval: int = 1
|
||||
fp8_format: FP8Format = "E4M3"
|
||||
amax_history_len: int = 1
|
||||
amax_compute_algo: AmaxComputeAlgorithm = "most_recent"
|
||||
override_linear_precision: Tuple[bool, bool, bool] = (False, False, False)
|
||||
backend: Backend = None
|
||||
use_autocast_during_eval: bool = None
|
||||
opt_level: OptLevel = None
|
||||
margin: int = None
|
||||
interval: int = None
|
||||
fp8_format: FP8Format = None
|
||||
amax_history_len: int = None
|
||||
amax_compute_algo: AmaxComputeAlgorithm = None
|
||||
override_linear_precision: Tuple[bool, bool, bool] = None
|
||||
|
||||
def __post_init__(self):
|
||||
if self.backend.upper() not in get_args(Backend):
|
||||
raise ValueError("`backend` must be 'MSAMP' or 'TE' (TransformerEngine).")
|
||||
|
||||
env_prefix = "ACCELERATE_FP8_"
|
||||
default_backend = "msamp" if is_msamp_available() else "te"
|
||||
if self.backend is None:
|
||||
self.backend = os.environ.get(env_prefix + "BACKEND", default_backend)
|
||||
self.backend = self.backend.upper()
|
||||
if self.backend not in get_args(Backend):
|
||||
raise ValueError("`backend` must be 'MSAMP' or 'TE' (TransformerEngine).")
|
||||
# Check TE args
|
||||
if self.backend == "TE":
|
||||
if not is_transformer_engine_available():
|
||||
raise ValueError(
|
||||
"TransformerEngine is not available. Please either install it, or use the 'MSAMP' backend (if installed)."
|
||||
)
|
||||
if self.use_autocast_during_eval is None:
|
||||
self.use_autocast_during_eval = parse_flag_from_env(env_prefix + "USE_AUTOCAST_DURING_EVAL")
|
||||
if self.margin is None:
|
||||
self.margin = int(os.environ.get(env_prefix + "MARGIN", 0))
|
||||
if self.interval is None:
|
||||
self.interval = int(os.environ.get(env_prefix + "INTERVAL", 1))
|
||||
if self.fp8_format is None:
|
||||
self.fp8_format = os.environ.get(env_prefix + "FORMAT", "E4M3")
|
||||
self.fp8_format = self.fp8_format.upper()
|
||||
if self.fp8_format not in get_args(FP8Format):
|
||||
raise ValueError(f"`fp8_format` must be one of {' or '.join(get_args(FP8Format))}.")
|
||||
if self.amax_compute_algo is None:
|
||||
self.amax_compute_algo = os.environ.get(env_prefix + "AMAX_COMPUTE_ALGO", "most_recent")
|
||||
self.amax_compute_algo = self.amax_compute_algo.lower()
|
||||
if self.amax_compute_algo not in get_args(AmaxComputeAlgorithm):
|
||||
raise ValueError(f"`amax_compute_algo` must be one of {' or '.join(get_args(AmaxComputeAlgorithm))}")
|
||||
if self.amax_history_len is None:
|
||||
self.amax_history_len = int(os.environ.get(env_prefix + "AMAX_HISTORY_LEN", 1024))
|
||||
if self.override_linear_precision is None:
|
||||
fprop = parse_flag_from_env(env_prefix + "OVERRIDE_FPROP")
|
||||
dgrad = parse_flag_from_env(env_prefix + "OVERRIDE_DGRAD")
|
||||
wgrad = parse_flag_from_env(env_prefix + "OVERRIDE_WGRAD")
|
||||
self.override_linear_precision = (fprop, dgrad, wgrad)
|
||||
elif self.backend == "MSAMP":
|
||||
if not is_msamp_available():
|
||||
raise ValueError(
|
||||
"MS-AMP is not available. Please either install it, or use the 'TE' backend (if installed)."
|
||||
)
|
||||
if self.opt_level is None:
|
||||
self.opt_level = os.environ.get(env_prefix + "OPT_LEVEL", "O2")
|
||||
if self.opt_level not in get_args(OptLevel):
|
||||
raise ValueError(f"`optimization_level` must be one of {' or '.join(get_args(OptLevel))}")
|
||||
|
||||
@ -534,6 +576,21 @@ class SageMakerDistributedType(str, enum.Enum):
|
||||
MODEL_PARALLEL = "MODEL_PARALLEL"
|
||||
|
||||
|
||||
class FP8BackendType(str, enum.Enum):
|
||||
"""
|
||||
Represents the backend used for FP8.
|
||||
|
||||
Values:
|
||||
|
||||
- **TE** -- using TransformerEngine.
|
||||
- **MSAMP** -- using msamp.
|
||||
"""
|
||||
|
||||
# Subclassing str as well as Enum allows the `FP8BackendType` to be JSON-serializable out of the box.
|
||||
TE = "TE"
|
||||
MSAMP = "MSAMP"
|
||||
|
||||
|
||||
class ComputeEnvironment(str, enum.Enum):
|
||||
"""
|
||||
Represents a type of the compute environment.
|
||||
@ -625,7 +682,7 @@ class LoggerType(BaseEnum):
|
||||
DVCLIVE = "dvclive"
|
||||
|
||||
|
||||
class PrecisionType(BaseEnum):
|
||||
class PrecisionType(str, BaseEnum):
|
||||
"""Represents a type of precision used on floating point values
|
||||
|
||||
Values:
|
||||
@ -1077,12 +1134,13 @@ class DeepSpeedPlugin:
|
||||
ds_config = self.deepspeed_config
|
||||
kwargs = {
|
||||
"fp16.enabled": mixed_precision == "fp16",
|
||||
"bf16.enabled": mixed_precision == "bf16",
|
||||
# When training in fp8, we still rely on bf16 autocast for the core mixed precision
|
||||
"bf16.enabled": mixed_precision in ("bf16", "fp8"),
|
||||
}
|
||||
if mixed_precision == "fp16":
|
||||
if "fp16" not in ds_config:
|
||||
ds_config["fp16"] = {"enabled": True, "auto_cast": True}
|
||||
elif mixed_precision == "bf16":
|
||||
elif mixed_precision in ("bf16", "fp8"):
|
||||
if "bf16" not in ds_config:
|
||||
ds_config["bf16"] = {"enabled": True}
|
||||
|
||||
@ -1497,7 +1555,12 @@ class FullyShardedDataParallelPlugin:
|
||||
|
||||
def set_mixed_precision(self, mixed_precision, buffer_autocast=False, override=False):
|
||||
"Sets the mixed precision policy for FSDP"
|
||||
mixed_precision_mapping = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp32": torch.float32}
|
||||
mixed_precision_mapping = {
|
||||
"fp8": torch.bfloat16,
|
||||
"fp16": torch.float16,
|
||||
"bf16": torch.bfloat16,
|
||||
"fp32": torch.float32,
|
||||
}
|
||||
dtype = mixed_precision
|
||||
if isinstance(mixed_precision, str):
|
||||
dtype = mixed_precision_mapping.get(mixed_precision, None)
|
||||
|
@ -102,7 +102,7 @@ def is_schedulefree_available():
|
||||
|
||||
|
||||
def is_transformer_engine_available():
|
||||
return _is_package_available("transformer_engine")
|
||||
return _is_package_available("transformer_engine", "transformer-engine")
|
||||
|
||||
|
||||
def is_lomo_available():
|
||||
|
@ -27,6 +27,7 @@ from ..commands.config.config_args import SageMakerConfig
|
||||
from ..utils import (
|
||||
DynamoBackend,
|
||||
PrecisionType,
|
||||
is_fp8_available,
|
||||
is_ipex_available,
|
||||
is_mlu_available,
|
||||
is_musa_available,
|
||||
@ -74,6 +75,19 @@ def _get_mpirun_args():
|
||||
return mpi_app, "-f", "-n", "-ppn", ""
|
||||
|
||||
|
||||
def setup_fp8_env(args: argparse.Namespace, current_env: Dict[str, str]):
|
||||
"""
|
||||
Setup the FP8 environment variables.
|
||||
"""
|
||||
prefix = "ACCELERATE_"
|
||||
for arg in vars(args):
|
||||
if arg.startswith("fp8_"):
|
||||
value = getattr(args, arg)
|
||||
if value is not None:
|
||||
current_env[f"{prefix}{arg.upper()}"] = str(getattr(args, arg))
|
||||
return current_env
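
In other words, every `fp8_*` launcher argument is mirrored into an `ACCELERATE_<ARG>` environment variable (so `--fp8_backend` becomes `ACCELERATE_FP8_BACKEND`), which `FP8RecipeKwargs.__post_init__` reads back inside the training process. A rough illustration with a hand-built namespace (hypothetical values, not a real launcher invocation):

```python
import argparse

# Stand-in for the parsed `accelerate launch` arguments.
args = argparse.Namespace(fp8_backend="te", fp8_margin=0, fp8_amax_history_len=1024)

env = {}
for arg, value in vars(args).items():
    if arg.startswith("fp8_") and value is not None:
        env[f"ACCELERATE_{arg.upper()}"] = str(value)

# env == {'ACCELERATE_FP8_BACKEND': 'te',
#         'ACCELERATE_FP8_MARGIN': '0',
#         'ACCELERATE_FP8_AMAX_HISTORY_LEN': '1024'}
```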
|
||||
|
||||
|
||||
def prepare_simple_launcher_cmd_env(args: argparse.Namespace) -> Tuple[List[str], Dict[str, str]]:
|
||||
"""
|
||||
Prepares and returns the command list and an environment with the correct simple launcher environment variables.
|
||||
@ -140,6 +154,12 @@ def prepare_simple_launcher_cmd_env(args: argparse.Namespace) -> Tuple[List[str]
|
||||
)
|
||||
|
||||
current_env["ACCELERATE_MIXED_PRECISION"] = str(mixed_precision)
|
||||
if args.mixed_precision.lower() == "fp8":
|
||||
if not is_fp8_available():
|
||||
raise RuntimeError(
|
||||
"FP8 is not available on this machine. Please ensure that either Transformer Engine or MSAMP is installed."
|
||||
)
|
||||
current_env = setup_fp8_env(args, current_env)
|
||||
|
||||
try:
|
||||
dynamo_backend = DynamoBackend(args.dynamo_backend.upper())
|
||||
@ -225,6 +245,12 @@ def prepare_multi_gpu_env(args: argparse.Namespace) -> Dict[str, str]:
|
||||
raise ValueError(f"Unknown mixed_precision mode: {mixed_precision}. Choose between {PrecisionType.list()}.")
|
||||
|
||||
current_env["ACCELERATE_MIXED_PRECISION"] = str(mixed_precision)
|
||||
if args.mixed_precision.lower() == "fp8":
|
||||
if not is_fp8_available():
|
||||
raise RuntimeError(
|
||||
"FP8 is not available on this machine. Please ensure that either Transformer Engine or MSAMP is installed."
|
||||
)
|
||||
current_env = setup_fp8_env(args, current_env)
|
||||
|
||||
try:
|
||||
dynamo_backend = DynamoBackend(args.dynamo_backend.upper())
|
||||
@ -390,6 +416,12 @@ def prepare_deepspeed_cmd_env(args: argparse.Namespace) -> Tuple[List[str], Dict
|
||||
|
||||
current_env["PYTHONPATH"] = env_var_path_add("PYTHONPATH", os.path.abspath("."))
|
||||
current_env["ACCELERATE_MIXED_PRECISION"] = str(mixed_precision)
|
||||
if args.mixed_precision.lower() == "fp8":
|
||||
if not is_fp8_available():
|
||||
raise RuntimeError(
|
||||
"FP8 is not available on this machine. Please ensure that either Transformer Engine or MSAMP is installed."
|
||||
)
|
||||
current_env = setup_fp8_env(args, current_env)
|
||||
current_env["ACCELERATE_CONFIG_DS_FIELDS"] = str(args.deepspeed_fields_from_accelerate_config).lower()
|
||||
current_env["ACCELERATE_USE_DEEPSPEED"] = "true"
|
||||
if args.zero_stage is not None:
|
||||
@ -528,6 +560,12 @@ def prepare_sagemager_args_inputs(
|
||||
"ACCELERATE_DYNAMO_USE_DYNAMIC": str(args.dynamo_use_dynamic),
|
||||
"ACCELERATE_SAGEMAKER_DISTRIBUTED_TYPE": sagemaker_config.distributed_type.value,
|
||||
}
|
||||
if args.mixed_precision.lower() == "fp8":
|
||||
if not is_fp8_available():
|
||||
raise RuntimeError(
|
||||
"FP8 is not available on this machine. Please ensure that either Transformer Engine or MSAMP is installed."
|
||||
)
|
||||
environment = setup_fp8_env(args, environment)
|
||||
# configure distribution set up
|
||||
distribution = None
|
||||
if sagemaker_config.distributed_type == SageMakerDistributedType.DATA_PARALLEL:
|
||||
|
@ -17,12 +17,13 @@ A set of basic tensor ops compatible with tpu, gpu, and multigpu
|
||||
|
||||
import pickle
|
||||
import warnings
|
||||
from contextlib import contextmanager, nullcontext
|
||||
from functools import update_wrapper, wraps
|
||||
from typing import Any, Mapping
|
||||
|
||||
import torch
|
||||
|
||||
from ..state import PartialState
|
||||
from ..state import AcceleratorState, PartialState
|
||||
from .constants import TORCH_DISTRIBUTED_OPERATION_TYPES
|
||||
from .dataclasses import DistributedType, TensorInformation
|
||||
from .imports import (
|
||||
@ -843,3 +844,25 @@ def find_device(data):
|
||||
return device
|
||||
elif isinstance(data, torch.Tensor):
|
||||
return data.device
|
||||
|
||||
|
||||
@contextmanager
|
||||
def GatheredParameters(params, modifier_rank=None, fwd_module=None, enabled=True):
|
||||
"""
|
||||
Wrapper around `deepspeed.runtime.zero.GatheredParameters`, but if Zero-3 is not enabled, will be a no-op context
|
||||
manager.
|
||||
"""
|
||||
# We need to use the `AcceleratorState` here since it has access to the deepspeed plugin
|
||||
if AcceleratorState().distributed_type != DistributedType.DEEPSPEED or (
|
||||
AcceleratorState().deepspeed_plugin is not None
|
||||
and not AcceleratorState().deepspeed_plugin.is_zero3_init_enabled()
|
||||
):
|
||||
gather_param_context = nullcontext()
|
||||
else:
|
||||
import deepspeed
|
||||
|
||||
gather_param_context = deepspeed.zero.GatheredParameters(
|
||||
params, modifier_rank=modifier_rank, fwd_module=fwd_module, enabled=enabled
|
||||
)
|
||||
with gather_param_context:
|
||||
yield
|
||||
|
@ -12,9 +12,12 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from types import MethodType
|
||||
|
||||
import torch.nn as nn
|
||||
|
||||
from .imports import is_fp8_available
|
||||
from .operations import GatheredParameters
|
||||
|
||||
|
||||
if is_fp8_available():
|
||||
@ -29,22 +32,28 @@ def convert_model(model, to_transformer_engine=True, _convert_linear=True, _conv
|
||||
raise ImportError("Using `convert_model` requires transformer_engine to be installed.")
|
||||
for name, module in model.named_children():
|
||||
if isinstance(module, nn.Linear) and to_transformer_engine and _convert_linear:
|
||||
# Return early if the linear layer weights are not multiples of 16
|
||||
if any(p % 16 != 0 for p in module.weight.shape):
|
||||
return
|
||||
has_bias = module.bias is not None
|
||||
te_module = te.Linear(
|
||||
module.in_features, module.out_features, bias=has_bias, params_dtype=module.weight.dtype
|
||||
)
|
||||
te_module.weight.copy_(module.weight)
|
||||
params_to_gather = [module.weight]
|
||||
if has_bias:
|
||||
te_module.bias.copy_(module.bias)
|
||||
params_to_gather.append(module.bias)
|
||||
|
||||
setattr(model, name, te_module)
|
||||
with GatheredParameters(params_to_gather, modifier_rank=0):
|
||||
if any(p % 16 != 0 for p in module.weight.shape):
|
||||
return
|
||||
te_module = te.Linear(
|
||||
module.in_features, module.out_features, bias=has_bias, params_dtype=module.weight.dtype
|
||||
)
|
||||
te_module.weight.copy_(module.weight)
|
||||
if has_bias:
|
||||
te_module.bias.copy_(module.bias)
|
||||
|
||||
setattr(model, name, te_module)
|
||||
# Note: @xrsrke (Phuc) found that te.LayerNorm doesn't have any real memory savings or speedups over nn.LayerNorm
|
||||
elif isinstance(module, nn.LayerNorm) and to_transformer_engine and _convert_ln:
|
||||
te_module = te.LayerNorm(module.normalized_shape[0], eps=module.eps, params_dtype=module.weight.dtype)
|
||||
te_module.weight.copy_(module.weight)
|
||||
te_module.bias.copy_(module.bias)
|
||||
with GatheredParameters([module.weight, module.bias], modifier_rank=0):
|
||||
te_module = te.LayerNorm(module.normalized_shape[0], eps=module.eps, params_dtype=module.weight.dtype)
|
||||
te_module.weight.copy_(module.weight)
|
||||
te_module.bias.copy_(module.bias)
|
||||
|
||||
setattr(model, name, te_module)
|
||||
elif isinstance(module, te.Linear) and not to_transformer_engine and _convert_linear:
|
||||
@ -82,3 +91,43 @@ def has_transformer_engine_layers(model):
|
||||
if isinstance(m, (te.LayerNorm, te.Linear, te.TransformerLayer)):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def contextual_fp8_autocast(model_forward, fp8_recipe, use_during_eval=False):
|
||||
"""
|
||||
Wrapper for a model's forward method to apply FP8 autocast. Is context aware, meaning that by default it will
|
||||
disable FP8 autocast during eval mode, which is generally better for more accurate metrics.
|
||||
"""
|
||||
from transformer_engine.pytorch import fp8_autocast
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
enabled = use_during_eval or self.training
|
||||
with fp8_autocast(enabled=enabled, fp8_recipe=fp8_recipe):
|
||||
return model_forward(*args, **kwargs)
|
||||
|
||||
# To act like a decorator so that it can be popped when doing `extract_model_from_parallel`
|
||||
forward.__wrapped__ = model_forward
|
||||
|
||||
return forward
|
||||
|
||||
|
||||
def apply_fp8_autowrap(model, fp8_recipe_handler):
|
||||
"""
|
||||
Applies FP8 context manager to the model's forward method
|
||||
"""
|
||||
# Import here to keep base imports fast
|
||||
import transformer_engine.common.recipe as te_recipe
|
||||
|
||||
kwargs = fp8_recipe_handler.to_kwargs() if fp8_recipe_handler is not None else {}
|
||||
if "fp8_format" in kwargs:
|
||||
kwargs["fp8_format"] = getattr(te_recipe.Format, kwargs["fp8_format"])
|
||||
use_during_eval = kwargs.pop("use_autocast_during_eval", False)
|
||||
fp8_recipe = te_recipe.DelayedScaling(**kwargs)
|
||||
new_forward = contextual_fp8_autocast(model.forward, fp8_recipe, use_during_eval)
|
||||
|
||||
if hasattr(model.forward, "__func__"):
|
||||
model.forward = MethodType(new_forward, model)
|
||||
else:
|
||||
model.forward = new_forward
|
||||
|
||||
return model
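
Taken together, `contextual_fp8_autocast` gates TE's `fp8_autocast` on `self.training` (unless `use_autocast_during_eval` is set) and stashes the original forward on `__wrapped__` so `extract_model_from_parallel` can remove it again, while `apply_fp8_autowrap` builds the `DelayedScaling` recipe from the kwargs handler and installs the wrapper. A stripped-down sketch of just the gating pattern, with the TE context replaced by a print for illustration:

```python
from types import MethodType

import torch


def gated_forward_factory(model_forward, use_during_eval=False):
    # Same shape as `contextual_fp8_autocast`, minus the te.fp8_autocast context.
    def forward(self, *args, **kwargs):
        enabled = use_during_eval or self.training
        print(f"fp8 autocast would be enabled: {enabled}")
        return model_forward(*args, **kwargs)

    forward.__wrapped__ = model_forward  # lets unwrapping code pop the wrapper later
    return forward


model = torch.nn.Linear(8, 8)
model.forward = MethodType(gated_forward_factory(model.forward), model)
model.train()
model(torch.randn(2, 8))  # prints: fp8 autocast would be enabled: True
model.eval()
model(torch.randn(2, 8))  # prints: fp8 autocast would be enabled: False
```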
|
||||
|
@ -27,9 +27,16 @@ from torch.utils.data import DataLoader, TensorDataset
|
||||
from accelerate import DistributedType, infer_auto_device_map, init_empty_weights, load_checkpoint_and_dispatch
|
||||
from accelerate.accelerator import Accelerator
|
||||
from accelerate.state import GradientState, PartialState
|
||||
from accelerate.test_utils import require_bnb, require_multi_gpu, require_non_cpu, slow, torch_device
|
||||
from accelerate.test_utils import (
|
||||
require_bnb,
|
||||
require_multi_gpu,
|
||||
require_non_cpu,
|
||||
require_transformer_engine,
|
||||
slow,
|
||||
torch_device,
|
||||
)
|
||||
from accelerate.test_utils.testing import AccelerateTestCase, require_cuda, require_non_torch_xla
|
||||
from accelerate.utils import patch_environment
|
||||
from accelerate.utils import FP8RecipeKwargs, patch_environment
|
||||
from accelerate.utils.modeling import get_state_dict_from_offload, load_checkpoint_in_model
|
||||
|
||||
|
||||
@ -561,6 +568,22 @@ class AcceleratorTester(AccelerateTestCase):
|
||||
accelerator = Accelerator(cpu=True)
|
||||
_ = accelerator.prepare(sgd)
|
||||
|
||||
@require_transformer_engine
|
||||
def test_can_unwrap_model_te(self):
|
||||
model, optimizer, *_ = create_components()
|
||||
fp8_recipe = FP8RecipeKwargs(backend="TE")
|
||||
accelerator = Accelerator(mixed_precision="fp8", kwargs_handlers=[fp8_recipe])
|
||||
inputs = torch.randn(10, 2).to(torch_device)
|
||||
model, optimizer = accelerator.prepare(model, optimizer)
|
||||
model(inputs) # sanity check that this works
|
||||
|
||||
model = accelerator.unwrap_model(model, keep_fp32_wrapper=False)
|
||||
model(inputs) # check that this still works
|
||||
|
||||
# check that pickle roundtrip works
|
||||
model_loaded = pickle.loads(pickle.dumps(model))
|
||||
model_loaded(inputs)
|
||||
|
||||
@require_non_cpu
|
||||
def test_can_unwrap_model_fp16(self):
|
||||
# test for a regression introduced in #872
|
||||
|
@ -73,7 +73,7 @@ class AccelerateLauncherTester(unittest.TestCase):
|
||||
execute_subprocess_async(cmd, env=os.environ.copy())
|
||||
|
||||
def test_config_compatibility(self):
|
||||
invalid_configs = ["invalid", "mpi", "sagemaker"]
|
||||
invalid_configs = ["fp8", "invalid", "mpi", "sagemaker"]
|
||||
for config in sorted(self.test_config_path.glob("**/*.yaml")):
|
||||
if any(invalid_config in str(config) for invalid_config in invalid_configs):
|
||||
continue
|
||||
|
tests/test_configs/0_34_0_fp8.yaml (new file, 26 lines)
@ -0,0 +1,26 @@
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: MULTI_GPU
downcast_bf16: 'no'
enable_cpu_affinity: false
fp8_config:
  amax_compute_algorithm: max
  amax_history_length: 1024
  backend: TE
  fp8_format: E4M3
  interval: 1
  margin: 0
  override_linear_precision: false
  use_autocast_during_eval: false
gpu_ids: all
machine_rank: 0
main_training_function: main
mixed_precision: fp8
num_machines: 1
num_processes: 2
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false