Enable FSDP & Deepspeed + FP8 (#2983)

* Working version rebased from main

* kwargs

* Clean

* Fix more nits

* Fin

* Delay autocast flag

* Enable FP8 autocast during eval only if specified

* Fin

* Rm comment

* All done

* Zero3 works!

* Let the wrapper come off during unwrap_model

* Add import check

* Migrate all to benchmarks folder and make TE import check work

* Add readme

* Add README to benchmarks folder

* Update CLI to now include fp8 args

* Add test config for 0_34

* Finish adding to config yaml

* Write docs

* Expound docs w/ FP8

* Add to toctree
Author: Zach Mueller (committed by GitHub)
Date: 2024-08-14 14:57:01 -04:00
Parent: 851cf34351
Commit: a452327e8e
31 changed files with 1446 additions and 123 deletions


@ -1,46 +1,5 @@
# Big model inference benchmarks
# Benchmarks
Running inference with Accelerate on big models.
The folders below contain suites to test various functionalities in Accelerate.
## Setup
These benchmarks use the `transformers` library:
```bash
pip install transformers
```
To reproduce or test a new setup, run
```py
python inference_acc.py model_name
```
This script supports `gpt-j-6b`, `gpt-neox`, `opt` (30B version) and `T0pp` out of the box, but you can specify any valid checkpoint for `model_name`.
To force a different `torch_dtype` than the one in the config: `--torch_dtype xxx`.
If you get an error linked to disk offload, you need to add the option `--disk-offload`.
## Results
On a setup with two Titan RTX GPUs (24GB of memory each) and 32GB of CPU RAM, we get the following benchmarks (T0pp does not run in float16, which is why it's not included).
| Model | Model load time | Generation time | dtype | GPU 0 use | GPU 1 use | CPU use | Disk offload |
|:-----:|:---------------:|:---------------:|:-----:|:---------:|:---------:|:-------:|:------------:|
| GPT-J-6B | 8.7s | 0.05s per token | float16 | 11.7GB | 0GB | 0GB | no |
| GPT-J-6B | 12.4s | 0.06s per token | float32 | 21.9GB | 1.5GB | 0GB | no |
| GPT-Neo-X-20B | 30.9s | 0.08s per token | float16 | 21.5GB | 18GB | 0GB | no |
| GPT-Neo-X-20B | 78.2s | 10.72s per token | float32 | 20.3GB | 22.7 GB | 24.4GB | yes |
| T0pp (11B) | 29.4s | 0.05s per token | float32 | 21.1GB | 21.3GB | 0GB | no |
| OPT-30B | 34.5s | 2.37s per token | float16 | 20.7GB | 22.3GB | 14.1GB | no |
| OPT-30B | 112.3s | 33.9s per token | float32 | 20.2GB | 21.2GB | 23.5GB | yes |
Notes on the results:
- using two GPUs instead of one does not slow down generation
- using CPU offload slows generation down a bit (see OPT-30B)
- using disk offload slows generation down a lot (we need to implement prefetching)
You will also note that Accelerate does not use any more GPU and CPU RAM than necessary:
- peak GPU memory is exactly the size of the model put on a given GPU
- peak CPU memory is either the size of the biggest checkpoint shard or the part of the model offloaded on CPU, whichever is bigger.
See their relevant README.md's for more information.


@ -0,0 +1,46 @@
# Big model inference benchmarks
Running inference with Accelerate on big models.
## Setup
These benchmarks use the `transformers` library:
```bash
pip install transformers
```
To reproduce or test a new setup, run
```py
python inference_acc.py model_name
```
This script supports `gpt-j-6b`, `gpt-neox`, `opt` (30B version) and `T0pp` out of the box, but you can specify any valid checkpoint for `model_name`.
To force a different `torch_dtype` than the one in the config: `--torch_dtype xxx`.
If you get an error linked to disk offload, you need to add the option `--disk-offload`.
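For example, to benchmark the 30B OPT checkpoint in float16 with disk offload enabled (a hypothetical invocation combining the flags above):
```bash
python inference_acc.py opt --torch_dtype float16 --disk-offload
```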
## Results
On a setup with two Titan RTX GPUs (24GB of memory each) and 32GB of CPU RAM, we get the following benchmarks (T0pp does not run in float16, which is why it's not included).
| Model | Model load time | Generation time | dtype | GPU 0 use | GPU 1 use | CPU use | Disk offload |
|:-----:|:---------------:|:---------------:|:-----:|:---------:|:---------:|:-------:|:------------:|
| GPT-J-6B | 8.7s | 0.05s per token | float16 | 11.7GB | 0GB | 0GB | no |
| GPT-J-6B | 12.4s | 0.06s per token | float32 | 21.9GB | 1.5GB | 0GB | no |
| GPT-Neo-X-20B | 30.9s | 0.08s per token | float16 | 21.5GB | 18GB | 0GB | no |
| GPT-Neo-X-20B | 78.2s | 10.72s per token | float32 | 20.3GB | 22.7 GB | 24.4GB | yes |
| T0pp (11B) | 29.4s | 0.05s per token | float32 | 21.1GB | 21.3GB | 0GB | no |
| OPT-30B | 34.5s | 2.37s per token | float16 | 20.7GB | 22.3GB | 14.1GB | no |
| OPT-30B | 112.3s | 33.9s per token | float32 | 20.2GB | 21.2GB | 23.5GB | yes |
Notes on the results:
- using two GPUs instead of one does not slow down generation
- using CPU offload slows generation down a bit (see OPT-30B)
- using disk offload slows generation down a lot (we need to implement prefetching)
You will also note that Accelerate does not use any more GPU and CPU RAM than necessary:
- peak GPU memory is exactly the size of the model put on a given GPU
- peak CPU memory is either the size of the biggest checkpoint shard or the part of the model offloaded on CPU, whichever is bigger.

benchmarks/fp8/Dockerfile (new file, 12 lines)

@ -0,0 +1,12 @@
FROM nvcr.io/nvidia/pytorch:24.07-py3
RUN pip install transformers evaluate datasets
RUN git clone https://github.com/huggingface/accelerate.git
RUN cd accelerate && pip install -e .
# Run benchmarks from the fp8 folder; drop into a shell by default when the container starts
WORKDIR accelerate/benchmarks/fp8
CMD ["/bin/bash"]

benchmarks/fp8/README.md (new file, 30 lines)

@ -0,0 +1,30 @@
# FP8 Benchmarks
Comparing and running [TransformerEngine](https://github.com/NVIDIA/TransformerEngine) FP8 with accelerate
## Overview
This repo provides scripts which compare native TransformerEngine model training against `accelerate`'s own integration. Each training paradigm has its own script, supporting the following:
* Single GPU training (`non_distributed.py`)
* Multi-GPU training via DistributedDataParallelism (`ddp.py`)
* Fully Sharded Data Parallelism (`fsdp.py`)
* DeepSpeed ZeRO 1-3 (`deepspeed.py`)
To run them, it's recommended to use a docker image (see the attached `Dockerfile`) and not install `TransformerEngine` manually.
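To build and enter the container, a typical flow might look like this (the image tag is arbitrary):
```bash
docker build -t accelerate-fp8 .
# `--gpus all` requires the NVIDIA Container Toolkit
docker run --gpus all -it accelerate-fp8
```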
## Running
All scripts can be run with the core `accelerate launch` command; no prior `accelerate config` is needed.
For single GPU, run it via `python`:
```bash
python non_distributed.py
```
For the rest, run them via `accelerate launch`:
```bash
accelerate launch ddp.py # or fsdp.py, deepspeed.py
```
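To pin the number of processes explicitly (for example, on a machine with 2 GPUs), pass `--num_processes` to the launcher:
```bash
accelerate launch --num_processes 2 fsdp.py
```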

benchmarks/fp8/ddp.py (new file, 143 lines)

@ -0,0 +1,143 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script tests to ensure that `accelerate` performs at the same level as raw `TransformerEngine`.
This particular script verifies this for DDP training.
"""
import evaluate
import torch
import transformer_engine.common.recipe as te_recipe
import transformer_engine.pytorch as te
from fp8_utils import evaluate_model, get_named_parameters, get_training_utilities
from torch.nn.parallel import DistributedDataParallel as DDP
from transformer_engine.common.recipe import DelayedScaling
from accelerate import Accelerator
from accelerate.state import AcceleratorState
from accelerate.utils import FP8RecipeKwargs, set_seed
from accelerate.utils.transformer_engine import convert_model
MODEL_NAME = "bert-base-cased"
METRIC = evaluate.load("glue", "mrpc")
def train_baseline():
set_seed(42)
model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = get_training_utilities(MODEL_NAME)
accelerator = Accelerator()
device = accelerator.device
model.to(device)
# Convert the model to TE
old_named_params = get_named_parameters(model)
with torch.no_grad():
convert_model(model)
FP8_RECIPE_KWARGS = {"fp8_format": te_recipe.Format.HYBRID, "amax_history_len": 32, "amax_compute_algo": "max"}
fp8_recipe = DelayedScaling(**FP8_RECIPE_KWARGS)
new_named_params = get_named_parameters(model)
# Convert the model to DDP
device_ids, output_device = [accelerator.local_process_index], accelerator.local_process_index
model = DDP(model, device_ids=device_ids, output_device=output_device)
mapping = {p: new_named_params[n] for n, p in old_named_params.items()}
for param_group in optimizer.param_groups:
param_group["params"] = [mapping[p] for p in param_group["params"]]
base_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator)
model.train()
for _ in range(2):
for batch in train_dataloader:
with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
batch = batch.to(device)
outputs = model(**batch)
loss = outputs.loss
loss.backward()
optimizer.step()
optimizer.zero_grad()
lr_scheduler.step()
trained_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator)
assert (
trained_model_results["accuracy"] > base_model_results["accuracy"]
), f'Accuracy should be higher for the trained model: {trained_model_results["accuracy"]} > {base_model_results["accuracy"]}'
assert (
trained_model_results["f1"] > base_model_results["f1"]
), f'F1 score should be higher for the trained model: {trained_model_results["f1"]} > {base_model_results["f1"]}'
return base_model_results, trained_model_results
def train_integration():
FP8_RECIPE_KWARGS = {"fp8_format": "HYBRID", "amax_history_len": 32, "amax_compute_algo": "max"}
kwargs_handlers = [FP8RecipeKwargs(backend="TE", **FP8_RECIPE_KWARGS)]
AcceleratorState()._reset_state(True)
accelerator = Accelerator(mixed_precision="fp8", kwargs_handlers=kwargs_handlers)
set_seed(42)
model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = get_training_utilities(
MODEL_NAME, accelerator=accelerator
)
model, optimizer = accelerator.prepare(model, optimizer)
base_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator)
model.train()
for _ in range(2):
for batch in train_dataloader:
outputs = model(**batch)
loss = outputs.loss
accelerator.backward(loss)
optimizer.step()
optimizer.zero_grad()
lr_scheduler.step()
trained_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator)
assert (
trained_model_results["accuracy"] > base_model_results["accuracy"]
), f'Accuracy should be higher for the trained model: {trained_model_results["accuracy"]} > {base_model_results["accuracy"]}'
assert (
trained_model_results["f1"] > base_model_results["f1"]
), f'F1 score should be higher for the trained model: {trained_model_results["f1"]} > {base_model_results["f1"]}'
return base_model_results, trained_model_results
if __name__ == "__main__":
baseline_not_trained, baseline_trained = train_baseline()
accelerator_not_trained, accelerator_trained = train_integration()
assert (
baseline_not_trained["accuracy"] == accelerator_not_trained["accuracy"]
), f'Accuracy should be the same for the baseline and accelerator: {baseline_not_trained["accuracy"]} == {accelerator_not_trained["accuracy"]}'
assert (
baseline_not_trained["f1"] == accelerator_not_trained["f1"]
), f'F1 score should be the same for the baseline and accelerator: {baseline_not_trained["f1"]} == {accelerator_not_trained["f1"]}'
assert (
baseline_trained["accuracy"] == accelerator_trained["accuracy"]
), f'Accuracy should be the same for the baseline and accelerator: {baseline_trained["accuracy"]} == {accelerator_trained["accuracy"]}'
assert (
baseline_trained["f1"] == accelerator_trained["f1"]
), f'F1 score should be the same for the baseline and accelerator: {baseline_trained["f1"]} == {accelerator_trained["f1"]}'
torch.distributed.destroy_process_group()


@ -0,0 +1,189 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script tests to ensure that `accelerate` performs at the same level as raw `TransformerEngine`.
This particular script verifies this for DeepSpeed training.
"""
from unittest.mock import patch
import deepspeed
import evaluate
import torch
import transformer_engine.common.recipe as te_recipe
import transformer_engine.pytorch as te
from fp8_utils import evaluate_model, get_named_parameters, get_training_utilities
from transformer_engine.common.recipe import DelayedScaling
from accelerate import Accelerator, DeepSpeedPlugin
from accelerate.state import AcceleratorState
from accelerate.utils import FP8RecipeKwargs, set_seed
from accelerate.utils.transformer_engine import convert_model
MODEL_NAME = "bert-base-cased"
METRIC = evaluate.load("glue", "mrpc")
def train_baseline(zero_stage: int = 1):
# This forces transformers to think Zero-3 Init should be used
with patch("transformers.integrations.deepspeed.is_deepspeed_zero3_enabled") as mock:
mock.return_value = zero_stage == 3
set_seed(42)
accelerator = Accelerator()
model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = get_training_utilities(
MODEL_NAME, accelerator=accelerator
)
# Convert the model to TE
old_named_params = get_named_parameters(model)
with torch.no_grad():
convert_model(model)
new_named_params = get_named_parameters(model)
mapping = {p: new_named_params[n] for n, p in old_named_params.items()}
for param_group in optimizer.param_groups:
param_group["params"] = [mapping[p] for p in param_group["params"]]
FP8_RECIPE_KWARGS = {"fp8_format": te_recipe.Format.HYBRID, "amax_history_len": 32, "amax_compute_algo": "max"}
fp8_recipe = DelayedScaling(**FP8_RECIPE_KWARGS)
import numpy as np
config = {
"train_batch_size": 32,
"train_micro_batch_size_per_gpu": 16,
"gradient_accumulation_steps": 1,
"zero_optimization": {
"stage": zero_stage,
"offload_optimizer": {"device": "none", "nvme_path": None},
"offload_param": {"device": "none", "nvme_path": None},
"stage3_gather_16bit_weights_on_model_save": False,
},
"gradient_clipping": 1.0,
"steps_per_print": np.inf,
"bf16": {"enabled": True},
"fp16": {"enabled": False},
"zero_allow_untested_optimizer": True,
}
(
model,
optimizer,
_,
_,
) = deepspeed.initialize(
model=model,
optimizer=optimizer,
config_params=config,
)
base_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator)
model.train()
model_outputs = []
data = []
for _ in range(2):
for batch in train_dataloader:
with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
outputs = model(**batch)
data.append(batch.to("cpu"))
model_outputs.append(outputs.logits.to("cpu"))
loss = outputs.loss
model.backward(loss)
model.step()
for _ in range(accelerator.num_processes):
lr_scheduler.step()
trained_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator)
model.destroy()
assert (
trained_model_results["accuracy"] > base_model_results["accuracy"]
), f'Accuracy should be higher for the trained model: {trained_model_results["accuracy"]} > {base_model_results["accuracy"]}'
assert (
trained_model_results["f1"] > base_model_results["f1"]
), f'F1 score should be higher for the trained model: {trained_model_results["f1"]} > {base_model_results["f1"]}'
return base_model_results, trained_model_results, model_outputs, data
def train_integration(zero_stage: int = 1):
set_seed(42)
FP8_RECIPE_KWARGS = {"fp8_format": "HYBRID", "amax_history_len": 32, "amax_compute_algo": "max"}
kwargs_handlers = [FP8RecipeKwargs(backend="TE", **FP8_RECIPE_KWARGS)]
AcceleratorState()._reset_state(True)
deepspeed_plugin = DeepSpeedPlugin(
zero_stage=zero_stage,
zero3_init_flag=zero_stage == 3,
)
accelerator = Accelerator(
mixed_precision="fp8", kwargs_handlers=kwargs_handlers, deepspeed_plugin=deepspeed_plugin
)
accelerator.state.deepspeed_plugin.deepspeed_config["train_micro_batch_size_per_gpu"] = 16
model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = get_training_utilities(
MODEL_NAME, accelerator=accelerator
)
model, optimizer, lr_scheduler = accelerator.prepare(model, optimizer, lr_scheduler)
base_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator)
model.train()
model_outputs = []
data = []
for _ in range(2):
for batch in train_dataloader:
outputs = model(**batch)
data.append(batch.to("cpu"))
model_outputs.append(outputs.logits.to("cpu"))
loss = outputs.loss
accelerator.backward(loss)
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad()
trained_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator)
model.destroy()
assert (
trained_model_results["accuracy"] > base_model_results["accuracy"]
), f'Accuracy should be higher for the trained model: {trained_model_results["accuracy"]} > {base_model_results["accuracy"]}'
assert (
trained_model_results["f1"] > base_model_results["f1"]
), f'F1 score should be higher for the trained model: {trained_model_results["f1"]} > {base_model_results["f1"]}'
return base_model_results, trained_model_results, model_outputs, data
if __name__ == "__main__":
# for zero_stage in [1, 2, 3]:
zero_stage = 1
baseline_not_trained, baseline_trained, baseline_outputs, baseline_data = train_baseline(zero_stage)
accelerator_not_trained, accelerator_trained, accelerator_outputs, accelerator_data = train_integration(zero_stage)
assert (
baseline_not_trained["accuracy"] == accelerator_not_trained["accuracy"]
), f'ZERO stage {zero_stage}: Accuracy should be the same for the baseline and accelerator: {baseline_not_trained["accuracy"]} == {accelerator_not_trained["accuracy"]}'
assert (
baseline_not_trained["f1"] == accelerator_not_trained["f1"]
), f'ZERO stage {zero_stage}: F1 score should be the same for the baseline and accelerator: {baseline_not_trained["f1"]} == {accelerator_not_trained["f1"]}'
assert (
baseline_trained["accuracy"] == accelerator_trained["accuracy"]
), f'ZERO stage {zero_stage}: Accuracy should be the same for the baseline and accelerator: {baseline_trained["accuracy"]} == {accelerator_trained["accuracy"]}'
assert (
baseline_trained["f1"] == accelerator_trained["f1"]
), f'ZERO stage {zero_stage}: F1 score should be the same for the baseline and accelerator: {baseline_trained["f1"]} == {accelerator_trained["f1"]}'
torch.distributed.destroy_process_group()

benchmarks/fp8/fp8_utils.py (new file, 115 lines)

@ -0,0 +1,115 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
def get_dataloaders(model_name: str, batch_size: int = 16):
from datasets import load_dataset
from torch.utils.data import DataLoader
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name)
datasets = load_dataset("glue", "mrpc")
def tokenize_function(examples):
# max_length=None => use the model max length (it's actually the default)
outputs = tokenizer(examples["sentence1"], examples["sentence2"], truncation=True, max_length=None)
return outputs
# Apply the method we just defined to all the examples in all the splits of the dataset
# starting with the main process first:
tokenized_datasets = datasets.map(
tokenize_function,
batched=True,
remove_columns=["idx", "sentence1", "sentence2"],
)
# We also rename the 'label' column to 'labels' which is the expected name for labels by the models of the
# transformers library
tokenized_datasets = tokenized_datasets.rename_column("label", "labels")
def collate_fn(examples):
return tokenizer.pad(
examples,
padding="longest",
            pad_to_multiple_of=16,  # Specific for FP8: TE's FP8 kernels expect sizes divisible by 16
return_tensors="pt",
)
# Instantiate dataloaders.
train_dataloader = DataLoader(
tokenized_datasets["train"], shuffle=True, collate_fn=collate_fn, batch_size=batch_size, drop_last=True
)
eval_dataloader = DataLoader(
tokenized_datasets["validation"],
shuffle=False,
collate_fn=collate_fn,
batch_size=16,
drop_last=True,
)
return train_dataloader, eval_dataloader
def get_training_utilities(model_name: str, batch_size: int = 16, accelerator=None):
"""
Returns a tuple of:
- Model
- Optimizer
- Train dataloader (prepared)
- Eval dataloader (prepared)
- LR Scheduler
Suitable for training on the MRPC dataset
"""
from torch.optim import AdamW
from transformers import AutoModelForSequenceClassification, get_linear_schedule_with_warmup
from accelerate import Accelerator
if accelerator is None:
accelerator = Accelerator()
model = AutoModelForSequenceClassification.from_pretrained(model_name)
train_dataloader, eval_dataloader = get_dataloaders(model_name, batch_size)
optimizer = AdamW(model.parameters(), lr=0.0001)
lr_scheduler = get_linear_schedule_with_warmup(
optimizer=optimizer,
num_warmup_steps=100,
num_training_steps=len(train_dataloader) * 2,
)
train_dataloader, eval_dataloader = accelerator.prepare(train_dataloader, eval_dataloader)
return model, optimizer, train_dataloader, eval_dataloader, lr_scheduler
def get_named_parameters(model):
"""
    Same thing as `Accelerator.get_named_parameters`: returns a dict of the model's named
    parameters (extracted from any parallel wrapper)
"""
from accelerate.utils import extract_model_from_parallel
model = extract_model_from_parallel(model)
return {n: p for n, p in model.named_parameters()}
def evaluate_model(model, dataloader, metric, accelerator=None):
"Turns model to .eval(), runs dataloader, calculates metric, then turns eval back on"
model.eval()
for step, batch in enumerate(dataloader):
with torch.no_grad():
outputs = model(**batch)
        predictions = outputs.logits.argmax(dim=-1)
        references = batch["labels"]
        if accelerator is not None and accelerator.num_processes > 1:
            predictions, references = accelerator.gather_for_metrics((predictions, references))
        metric.add_batch(predictions=predictions, references=references)
return metric.compute()

benchmarks/fp8/fsdp.py (new file, 160 lines)

@ -0,0 +1,160 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script tests to ensure that `accelerate` performs at the same level as raw `TransformerEngine`.
This particular script verifies this for FSDP training.
"""
from functools import partial
import evaluate
import torch
import transformer_engine.common.recipe as te_recipe
import transformer_engine.pytorch as te
from fp8_utils import evaluate_model, get_named_parameters, get_training_utilities
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import MixedPrecision
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from transformer_engine.common.recipe import DelayedScaling
from transformers.models.bert import BertLayer
from accelerate import Accelerator
from accelerate import FullyShardedDataParallelPlugin as FSDPPlugin
from accelerate.state import AcceleratorState
from accelerate.utils import FP8RecipeKwargs, set_seed
from accelerate.utils.transformer_engine import convert_model
MODEL_NAME = "bert-base-cased"
METRIC = evaluate.load("glue", "mrpc")
FSDP_WRAP_POLICY = partial(transformer_auto_wrap_policy, transformer_layer_cls={BertLayer})
def train_baseline():
set_seed(42)
model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = get_training_utilities(MODEL_NAME)
accelerator = Accelerator()
device = accelerator.device
model.to(device)
# Convert the model to TE
old_named_params = get_named_parameters(model)
with torch.no_grad():
convert_model(model)
FP8_RECIPE_KWARGS = {"fp8_format": te_recipe.Format.HYBRID, "amax_history_len": 32, "amax_compute_algo": "max"}
fp8_recipe = DelayedScaling(**FP8_RECIPE_KWARGS)
new_named_params = get_named_parameters(model)
# Convert the model to FSDP
model = FSDP(
model,
use_orig_params=True,
mixed_precision=MixedPrecision(param_dtype=torch.bfloat16, reduce_dtype=torch.float32),
auto_wrap_policy=FSDP_WRAP_POLICY,
)
mapping = {p: new_named_params[n] for n, p in old_named_params.items()}
for param_group in optimizer.param_groups:
param_group["params"] = [mapping[p] for p in param_group["params"]]
base_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator)
model.train()
for _ in range(2):
for batch in train_dataloader:
with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
batch = batch.to(device)
outputs = model(**batch)
loss = outputs.loss
loss.backward()
optimizer.step()
optimizer.zero_grad()
lr_scheduler.step()
trained_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator)
assert (
trained_model_results["accuracy"] > base_model_results["accuracy"]
), f'Accuracy should be higher for the trained model: {trained_model_results["accuracy"]} > {base_model_results["accuracy"]}'
assert (
trained_model_results["f1"] > base_model_results["f1"]
), f'F1 score should be higher for the trained model: {trained_model_results["f1"]} > {base_model_results["f1"]}'
return base_model_results, trained_model_results
def train_integration():
FP8_RECIPE_KWARGS = {"fp8_format": "HYBRID", "amax_history_len": 32, "amax_compute_algo": "max"}
kwargs_handlers = [FP8RecipeKwargs(backend="TE", **FP8_RECIPE_KWARGS)]
AcceleratorState()._reset_state(True)
fsdp_plugin = FSDPPlugin(
auto_wrap_policy=FSDP_WRAP_POLICY,
use_orig_params=True,
mixed_precision_policy=MixedPrecision(param_dtype=torch.bfloat16, reduce_dtype=torch.float32),
)
accelerator = Accelerator(mixed_precision="fp8", fsdp_plugin=fsdp_plugin, kwargs_handlers=kwargs_handlers)
set_seed(42)
model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = get_training_utilities(
MODEL_NAME, accelerator=accelerator
)
model, optimizer = accelerator.prepare(model, optimizer)
base_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator)
model.train()
for _ in range(2):
for batch in train_dataloader:
outputs = model(**batch)
loss = outputs.loss
accelerator.backward(loss)
optimizer.step()
optimizer.zero_grad()
lr_scheduler.step()
trained_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator)
assert (
trained_model_results["accuracy"] > base_model_results["accuracy"]
), f'Accuracy should be higher for the trained model: {trained_model_results["accuracy"]} > {base_model_results["accuracy"]}'
assert (
trained_model_results["f1"] > base_model_results["f1"]
), f'F1 score should be higher for the trained model: {trained_model_results["f1"]} > {base_model_results["f1"]}'
return base_model_results, trained_model_results
if __name__ == "__main__":
baseline_not_trained, baseline_trained = train_baseline()
accelerator_not_trained, accelerator_trained = train_integration()
assert (
baseline_not_trained["accuracy"] == accelerator_not_trained["accuracy"]
), f'Accuracy should be the same for the baseline and accelerator: {baseline_not_trained["accuracy"]} == {accelerator_not_trained["accuracy"]}'
assert (
baseline_not_trained["f1"] == accelerator_not_trained["f1"]
), f'F1 score should be the same for the baseline and accelerator: {baseline_not_trained["f1"]} == {accelerator_not_trained["f1"]}'
assert (
baseline_trained["accuracy"] == accelerator_trained["accuracy"]
), f'Accuracy should be the same for the baseline and accelerator: {baseline_trained["accuracy"]} == {accelerator_trained["accuracy"]}'
assert (
baseline_trained["f1"] == accelerator_trained["f1"]
), f'F1 score should be the same for the baseline and accelerator: {baseline_trained["f1"]} == {accelerator_trained["f1"]}'
torch.distributed.destroy_process_group()


@ -0,0 +1,131 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script tests to ensure that `accelerate` performs at the same level as raw `TransformerEngine`.
This particular script verifies this for single GPU training.
"""
import evaluate
import torch
import transformer_engine.common.recipe as te_recipe
import transformer_engine.pytorch as te
from fp8_utils import evaluate_model, get_named_parameters, get_training_utilities
from transformer_engine.common.recipe import DelayedScaling
from accelerate import Accelerator
from accelerate.state import AcceleratorState
from accelerate.utils import FP8RecipeKwargs, set_seed
from accelerate.utils.transformer_engine import convert_model
MODEL_NAME = "bert-base-cased"
METRIC = evaluate.load("glue", "mrpc")
def train_baseline():
set_seed(42)
model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = get_training_utilities(MODEL_NAME)
# Convert the model to TE
old_named_params = get_named_parameters(model)
with torch.no_grad():
convert_model(model)
new_named_params = get_named_parameters(model)
mapping = {p: new_named_params[n] for n, p in old_named_params.items()}
for param_group in optimizer.param_groups:
param_group["params"] = [mapping[p] for p in param_group["params"]]
FP8_RECIPE_KWARGS = {"fp8_format": te_recipe.Format.HYBRID, "amax_history_len": 32, "amax_compute_algo": "max"}
fp8_recipe = DelayedScaling(**FP8_RECIPE_KWARGS)
model.to("cuda")
base_model_results = evaluate_model(model, eval_dataloader, METRIC)
model.train()
for batch in train_dataloader:
with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
batch = batch.to("cuda")
outputs = model(**batch)
loss = outputs.loss
loss.backward()
optimizer.step()
optimizer.zero_grad()
lr_scheduler.step()
trained_model_results = evaluate_model(model, eval_dataloader, METRIC)
assert (
trained_model_results["accuracy"] > base_model_results["accuracy"]
), f'Accuracy should be higher for the trained model: {trained_model_results["accuracy"]} > {base_model_results["accuracy"]}'
assert (
trained_model_results["f1"] > base_model_results["f1"]
), f'F1 score should be higher for the trained model: {trained_model_results["f1"]} > {base_model_results["f1"]}'
return base_model_results, trained_model_results
def train_integration():
FP8_RECIPE_KWARGS = {"fp8_format": "HYBRID", "amax_history_len": 32, "amax_compute_algo": "max"}
kwargs_handlers = [FP8RecipeKwargs(backend="TE", **FP8_RECIPE_KWARGS)]
AcceleratorState()._reset_state(True)
accelerator = Accelerator(mixed_precision="fp8", kwargs_handlers=kwargs_handlers)
set_seed(42)
model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = get_training_utilities(
MODEL_NAME, accelerator=accelerator
)
model, optimizer, lr_scheduler = accelerator.prepare(model, optimizer, lr_scheduler)
base_model_results = evaluate_model(model, eval_dataloader, METRIC)
model.train()
for batch in train_dataloader:
outputs = model(**batch)
loss = outputs.loss
accelerator.backward(loss)
optimizer.step()
optimizer.zero_grad()
lr_scheduler.step()
trained_model_results = evaluate_model(model, eval_dataloader, METRIC)
assert (
trained_model_results["accuracy"] > base_model_results["accuracy"]
), f'Accuracy should be higher for the trained model: {trained_model_results["accuracy"]} > {base_model_results["accuracy"]}'
assert (
trained_model_results["f1"] > base_model_results["f1"]
), f'F1 score should be higher for the trained model: {trained_model_results["f1"]} > {base_model_results["f1"]}'
return base_model_results, trained_model_results
if __name__ == "__main__":
baseline_not_trained, baseline_trained = train_baseline()
accelerator_not_trained, accelerator_trained = train_integration()
assert (
baseline_not_trained["accuracy"] == accelerator_not_trained["accuracy"]
), f'Accuracy should be the same for the baseline and accelerator: {baseline_not_trained["accuracy"]} == {accelerator_not_trained["accuracy"]}'
assert (
baseline_not_trained["f1"] == accelerator_not_trained["f1"]
), f'F1 score should be the same for the baseline and accelerator: {baseline_not_trained["f1"]} == {accelerator_not_trained["f1"]}'
assert (
baseline_trained["accuracy"] == accelerator_trained["accuracy"]
), f'Accuracy should be the same for the baseline and accelerator: {baseline_trained["accuracy"]} == {accelerator_trained["accuracy"]}'
assert (
baseline_trained["f1"] == accelerator_trained["f1"]
), f'F1 score should be the same for the baseline and accelerator: {baseline_trained["f1"]} == {accelerator_trained["f1"]}'


@ -112,6 +112,8 @@
title: Distributed inference with big models
- local: package_reference/kwargs
title: Kwargs handlers
- local: package_reference/fp8
title: FP8 Functionality
- local: package_reference/utilities
title: Utility functions and classes
- local: package_reference/megatron_lm


@ -145,10 +145,11 @@ values. They can also be passed in manually.
The following arguments are useful for fine-tuning how available hardware should be used
* `--mixed_precision {no,fp16,bf16}` (`str`) -- Whether or not to use mixed precision training. Choose between FP16 and BF16 (bfloat16) training. BF16 training is only supported on Nvidia Ampere GPUs and PyTorch 1.10 or later.
* `--mixed_precision {no,fp16,bf16,fp8}` (`str`) -- Whether or not to use mixed precision training. Choose between FP16, BF16 (bfloat16), or FP8 training. BF16 training is only supported on Nvidia Ampere GPUs and PyTorch 1.10 or later.
* `--num_processes NUM_PROCESSES` (`int`) -- The total number of processes to be launched in parallel.
* `--num_machines NUM_MACHINES` (`int`) -- The total number of machines used in this training.
* `--num_cpu_threads_per_process NUM_CPU_THREADS_PER_PROCESS` (`int`) -- The number of CPU threads per process. Can be tuned for optimal performance.
* `--enable_cpu_affinity` (`bool`) -- Whether or not CPU affinity and balancing should be enabled. Currently only supported on NVIDIA hardware.
**Training Paradigm Arguments**:
@ -165,19 +166,26 @@ The following arguments are only useful when `multi_gpu` is passed or multi-gpu
* `--gpu_ids` (`str`) -- What GPUs (by id) should be used for training on this machine as a comma-separated list
* `--same_network` (`bool`) -- Whether all machines used for multinode training exist on the same local network.
* `--machine_rank MACHINE_RANK` (`int`) -- The rank of the machine on which this script is launched.
* `--main_process_ip MAIN_PROCESS_IP` (`str`) -- The IP address of the machine of rank 0.
* `--main_process_port MAIN_PROCESS_PORT` (`int`) -- The port to use to communicate with the machine of rank 0.
* `--rdzv_backend` (`str`) -- The rendezvous method to use, such as "static" or "c10d"
* `--machine_rank` (`int`) -- The rank of the machine on which this script is launched.
* `--main_process_ip` (`str`) -- The IP address of the machine of rank 0.
* `--main_process_port` (`int`) -- The port to use to communicate with the machine of rank 0.
* `-t`, `--tee` (`str`) -- Tee std streams into a log file and also to console.
* `--log_dir` (`str`) -- Base directory to use for log files when using torchrun/torch.distributed.run as launcher. Use with --tee to redirect std streams info log files.
* `--role` (`str`) -- User-defined role for the workers.
* `--rdzv_backend` (`str`) -- The rendezvous method to use, such as 'static' (the default) or 'c10d'
* `--rdzv_conf` (`str`) -- Additional rendezvous configuration (<key1>=<value1>,<key2>=<value2>,...).
* `--max_restarts` (`int`) -- Maximum number of worker group restarts before failing.
* `--monitor_interval` (`float`) -- Interval, in seconds, to monitor the state of workers.
* `--monitor_interval` (`int`) -- Interval, in seconds, to monitor the state of workers.
**TPU Arguments**:
The following arguments are only useful when `tpu` is passed or TPU training is configured through `accelerate config`:
* `--main_training_function MAIN_TRAINING_FUNCTION` (`str`) -- The name of the main function to be executed in your script.
* `--tpu_cluster` (`bool`) -- Whether to use a GCP TPU pod for training.
* `--tpu_use_sudo` (`bool`) -- Whether to use `sudo` when running the TPU training script in each pod.
* `--vm` (`str`) -- List of single Compute VM instance names. If not provided we assume usage of instance groups. For TPU pods.
* `--env` (`str`) -- List of environment variables to set on the Compute VM instances. For TPU pods.
* `--main_training_function` (`str`) -- The name of the main function to be executed in your script (only for TPU training).
* `--downcast_bf16` (`bool`) -- Whether, when using bf16 precision on TPUs, both float and double tensors are cast to bfloat16, or double tensors remain as float32.
**DeepSpeed Arguments**:
@ -188,6 +196,7 @@ The following arguments are only useful when `use_deepspeed` is passed or `deeps
* `--zero_stage` (`int`) -- DeepSpeed's ZeRO optimization stage.
* `--offload_optimizer_device` (`str`) -- Decides where (none|cpu|nvme) to offload optimizer states.
* `--offload_param_device` (`str`) -- Decides where (none|cpu|nvme) to offload parameters.
* `--offload_optimizer_nvme_path` (`str`) -- Decides the NVMe path to offload optimizer states to.
* `--gradient_accumulation_steps` (`int`) -- Number of gradient accumulation steps used in your training script.
* `--gradient_clipping` (`float`) -- Gradient clipping value used in your training script.
* `--zero3_init_flag` (`str`) -- Decides whether (true|false) to enable `deepspeed.zero.Init` for constructing massive models. Only applicable with DeepSpeed ZeRO Stage-3.
@ -196,6 +205,7 @@ The following arguments are only useful when `use_deepspeed` is passed or `deeps
* `--deepspeed_exclusion_filter` (`str`) -- DeepSpeed exclusion filter string when using multi-node setup.
* `--deepspeed_inclusion_filter` (`str`) -- DeepSpeed inclusion filter string when using multi-node setup.
* `--deepspeed_multinode_launcher` (`str`) -- DeepSpeed multi-node launcher to use.
* `--deepspeed_moe_layer_cls_names` (`str`) -- comma-separated list of transformer MoE layer class names (case-sensitive) to wrap, e.g., `MixtralSparseMoeBlock`, `Qwen2MoeSparseMoeBlock`, `JetMoEAttention,JetMoEBlock`
**Fully Sharded Data Parallelism Arguments**:
@ -210,8 +220,9 @@ The following arguments are only useful when `use_fsdp` is passed or Fully Shard
* `--fsdp_state_dict_type` (`str`) -- FSDP's state dict type.
* `--fsdp_forward_prefetch` (`str`) -- FSDP forward prefetch.
* `--fsdp_use_orig_params` (`str`) -- If True, allows non-uniform `requires_grad` mixed in a FSDP unit.
* `--fsdp_cpu_ram_efficient_loading` (`str`) - If true, only the first process loads the pretrained model checkpoint while all other processes have empty weights. When using this, `--fsdp_sync_module_states` needs to be true.
* `--fsdp_sync_module_states` (`str`) - If true, each individually wrapped FSDP unit will broadcast module parameters from rank 0.
* `--fsdp_cpu_ram_efficient_loading` (`str`) -- If true, only the first process loads the pretrained model checkpoint while all other processes have empty weights. When using this, `--fsdp_sync_module_states` needs to be true.
* `--fsdp_sync_module_states` (`str`) -- If true, each individually wrapped FSDP unit will broadcast module parameters from rank 0.
* `--fsdp_activation_checkpointing` (`bool`) -- Decides whether intermediate activations are freed during the forward pass, and a checkpoint is left as a placeholder
**Megatron-LM Arguments**:
@ -225,6 +236,18 @@ The following arguments are only useful when `use_megatron_lm` is passed or Mega
* `--megatron_lm_use_distributed_optimizer` (``) -- Decides whether (true|false) to use a distributed optimizer which shards optimizer state and gradients across Data Parallel (DP) ranks.
* `--megatron_lm_gradient_clipping` (``) -- Megatron-LM's gradient clipping value based on global L2 Norm (0 to disable).
**FP8 Arguments**:
* `--fp8_backend` (`str`) -- Choose a backend to train with FP8 (`te` or `msamp`)
* `--fp8_use_autocast_during_eval` (`bool`) -- Whether to use FP8 autocast during eval mode (useful only when `--fp8_backend=te` is passed). Generally better metrics are found when this is not passed.
* `--fp8_margin` (`int`) -- The margin to use for the gradient scaling (useful only when `--fp8_backend=te` is passed).
* `--fp8_interval` (`int`) -- The interval to use for how often the scaling factor is recomputed (useful only when `--fp8_backend=te` is passed).
* `--fp8_format` (`str`) -- The format to use for the FP8 recipe (useful only when `--fp8_backend=te` is passed).
* `--fp8_amax_history_len` (`int`) -- The length of the history to use for the scaling factor computation (useful only when `--fp8_backend=te` is passed).
* `--fp8_amax_compute_algo` (`str`) -- The algorithm to use for the scaling factor computation. (useful only when `--fp8_backend=te` is passed).
* `--fp8_override_linear_precision` (`Tuple[bool, bool, bool]`) -- Whether or not to execute `fprop`, `dgrad`, and `wgrad` GEMMs in higher precision.
* `--fp8_opt_level` (`str`) -- What level of 8-bit collective communication should be used with MS-AMP (useful only when `--fp8_backend=msamp` is passed).
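As an illustration, a TransformerEngine-backed FP8 launch combining these flags might look like the following sketch (`train.py` stands in for your own script):
```bash
accelerate launch --mixed_precision=fp8 --fp8_backend=te \
  --fp8_format=HYBRID --fp8_amax_history_len=32 --fp8_amax_compute_algo=max \
  train.py
```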
**AWS SageMaker Arguments**:
The following arguments are only useful when training in SageMaker


@ -0,0 +1,28 @@
<!--Copyright 2021 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.
-->
# FP8 Functionality
Below are the functions and classes relevant to the underlying FP8 implementation:
[[autodoc]] utils.FP8RecipeKwargs
[[autodoc]] utils.convert_model
[[autodoc]] utils.has_transformer_engine_layers
[[autodoc]] utils.contextual_fp8_autocast
[[autodoc]] utils.apply_fp8_autowrap
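As a brief sketch of how these utilities fit together (the toy model is hypothetical and `transformer_engine` must be installed):
```python
import torch

from accelerate.utils import FP8RecipeKwargs, apply_fp8_autowrap, convert_model, has_transformer_engine_layers

# A toy module standing in for a real model
model = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.LayerNorm(64))
if not has_transformer_engine_layers(model):
    with torch.no_grad():
        convert_model(model)  # swaps supported layers (Linear, LayerNorm) for TE equivalents in place
# Wrap the forward pass in TE's `fp8_autocast` using a recipe built from the handler
model = apply_fp8_autowrap(model, FP8RecipeKwargs(backend="TE"))
```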


@ -39,7 +39,7 @@ from accelerate import Accelerator
accelerator = Accelerator(mixed_precision="fp8")
```
By default, if `MS-AMP` is available in your environment, 🤗 Accelerate will automatically utilize it as a backend. To specify it yourself (and customize other parts of the FP8 mixed precision setup), you can utilize the [`utils.FP8RecipeKwargs`]:
By default, if `MS-AMP` is available in your environment, 🤗 Accelerate will automatically utilize it as a backend. To specify it yourself (and customize other parts of the FP8 mixed precision setup), you can utilize [`utils.FP8RecipeKwargs`] or specify it in your config `yaml`/during `accelerate launch`:
```{python}
from accelerate import Accelerator
@ -50,6 +50,19 @@ kwargs = [FP8RecipeKwargs(backend="msamp")]
accelerator = Accelerator(mixed_precision="fp8", kwargs_handlers=kwargs)
```
```{yaml}
mixed_precision: fp8
fp8_config:
amax_compute_algorithm: max
amax_history_length: 1024
backend: TE
fp8_format: E4M3
interval: 1
margin: 0
override_linear_precision: false
use_autocast_during_eval: false
```
## Configuring MS-AMP
Of the two, `MS-AMP` is traditionally the easier one to configure as there is only a single argument: the optimization level.
@ -68,6 +81,17 @@ kwargs = [FP8RecipeKwargs(backend="msamp", optimization_level="O2")]
accelerator = Accelerator(mixed_precision="fp8", kwargs_handlers=kwargs)
```
Or during `accelerate launch` via `--fp8_backend=msamp --fp8_opt_level=O2`.
Similarly, this can be set in your `config.yaml`:
```{yaml}
mixed_precision: fp8
fp8_config:
backend: MSAMP
opt_level: O2
```
## Configuring TransformersEngine
TransformerEngine has many more options for customizing how and what FP8 calculations are performed. A full list of supported arguments and what they mean is available in [NVIDIA's documentation](https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/api/common.html); they are also restated as part of [`FP8KwargsHandler`]'s docstring for your convenience.
@ -83,6 +107,35 @@ kwargs = [FP8RecipeKwargs(backend="te", ...)]
accelerator = Accelerator(mixed_precision="fp8", kwargs_handlers=kwargs)
```
Or during `accelerate launch` via `--fp8_backend=te ...`. Use `accelerate launch --fp8_backend=te -h` to see the relevant arguments.
Similarly, this can be set in your `config.yaml`:
```{yaml}
mixed_precision: fp8
fp8_config:
amax_compute_algorithm: max
amax_history_length: 1024
backend: TE
fp8_format: E4M3
interval: 1
margin: 0
override_linear_precision: false
use_autocast_during_eval: false
```
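For instance, a handler mirroring the recipe used in the FP8 benchmark scripts would look like this (values taken from `benchmarks/fp8/ddp.py`; tune them for your own workload):
```{python}
from accelerate import Accelerator
from accelerate.utils import FP8RecipeKwargs

kwargs = [FP8RecipeKwargs(backend="TE", fp8_format="HYBRID", amax_history_len=32, amax_compute_algo="max")]
accelerator = Accelerator(mixed_precision="fp8", kwargs_handlers=kwargs)
```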
## Example Zoo
We have examples showcasing FP8 training, both with Accelerate and with the underlying implementations directly, available in the accelerate repo.
Currently we provide scripts showcasing:
* Single GPU
* Distributed Data Parallelism (Multi-GPU)
* Fully Sharded Data Parallelism
* DeepSpeed ZeRO 1 through 3
Find out more [here](https://github.com/huggingface/accelerate/tree/main/benchmarks/fp8)
## Further Reading
To learn more about training in FP8 please check out the following resources:


@ -68,6 +68,7 @@ from .utils import (
ProjectConfiguration,
RNGType,
TorchDynamoPlugin,
apply_fp8_autowrap,
check_os_kernel,
clean_state_dict_for_safetensors,
compare_versions,
@ -78,7 +79,6 @@ from .utils import (
gather_object,
get_mixed_precision_context_manager,
get_pretty_name,
has_transformer_engine_layers,
is_bf16_available,
is_deepspeed_available,
is_ipex_available,
@ -391,10 +391,15 @@ class Accelerator:
**kwargs,
)
if self.state.mixed_precision == "fp8" and self.fp8_recipe_handler is None:
self.fp8_recipe_handler = FP8RecipeKwargs()
self.delayed_fp8_autocast = False
if self.fp8_recipe_handler is not None:
# We already check if FP8 is available during `self.state`
if self.state.mixed_precision != "fp8":
if self.state.mixed_precision != "fp8" and (
self.distributed_type not in (DistributedType.FSDP, DistributedType.DEEPSPEED)
):
raise ValueError("Passing in a `FP8RecipeKwargs` object requires setting `mixed_precision='fp8'`.")
self.delayed_fp8_autocast = self.fp8_recipe_handler.backend == "TE" and self.distributed_type in (
DistributedType.MULTI_GPU,
@ -1290,7 +1295,8 @@ class Accelerator:
# If we're dealing with device placement, this deals with that by...
tpu_should_fix_optimizer = self.device_placement and self.distributed_type == DistributedType.XLA
if tpu_should_fix_optimizer or (self.mixed_precision == "fp8" and self.fp8_recipe_handler.backend == "TE"):
if tpu_should_fix_optimizer:
# 1. grabbing old model parameters
old_named_params = self._get_named_parameters(*args)
@ -1299,6 +1305,8 @@ class Accelerator:
args = self._prepare_ipex_or_xpu(*args)
elif self.device.type == "xpu" and is_xpu_available():
args = self._prepare_ipex_or_xpu(*args)
if self.fp8_recipe_handler is not None and self.fp8_recipe_handler.backend == "TE":
args = self._prepare_te(*args)
if self.distributed_type == DistributedType.DEEPSPEED:
result = self._prepare_deepspeed(*args)
elif self.distributed_type == DistributedType.MEGATRON_LM:
@ -1312,8 +1320,7 @@ class Accelerator:
self._prepare_one(obj, first_pass=True, device_placement=d) for obj, d in zip(args, device_placement)
)
result = tuple(self._prepare_one(obj, device_placement=d) for obj, d in zip(result, device_placement))
if tpu_should_fix_optimizer or (self.mixed_precision == "fp8" and self.fp8_recipe_handler.backend == "TE"):
if tpu_should_fix_optimizer:
# 2. grabbing new model parameters
new_named_params = self._get_named_parameters(*result)
# 3. building a map from the first to the second
@ -1384,24 +1391,8 @@ class Accelerator:
model.forward = convert_outputs_to_fp32(new_forward)
# We prepare fp8 after, allowing for bf16 autocast to happen first
if getattr(self.fp8_recipe_handler, "backend", None) == "TE":
# Import here to keep base imports fast
import transformer_engine.common.recipe as te_recipe
from transformer_engine.pytorch import fp8_autocast
if not has_transformer_engine_layers(model):
with torch.no_grad():
convert_model(model)
model._converted_to_transformer_engine = True
kwargs = self.fp8_recipe_handler.to_kwargs() if self.fp8_recipe_handler is not None else {}
if "fp8_format" in kwargs:
kwargs["fp8_format"] = getattr(te_recipe.Format, kwargs["fp8_format"])
fp8_recipe = te_recipe.DelayedScaling(**kwargs)
# If we are in DDP or FSDP, we delay `autocast` until after FSDP/DDP has been initialized
# to make use of the process group
if not self.delayed_fp8_autocast:
model.forward = fp8_autocast(enabled=True, fp8_recipe=fp8_recipe)(model.forward)
if getattr(self.fp8_recipe_handler, "backend", None) == "TE" and not self.delayed_fp8_autocast:
model = apply_fp8_autowrap(model, self.fp8_recipe_handler)
if (getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False)) and getattr(
model, "hf_device_map", False
@ -1455,6 +1446,7 @@ class Accelerator:
if self.ddp_handler is not None:
self.ddp_handler.register_comm_hook(model)
elif self.distributed_type == DistributedType.FSDP:
# We need to fix the optimizer *before* sharding the model
from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP
# Check if the model is already a FSDP model due to `Manual Wrapping` and if so,
@ -1578,9 +1570,7 @@ class Accelerator:
model = xmp.MpModelWrapper(model).to(self.device)
# Now we can apply the FP8 autocast
if self.delayed_fp8_autocast:
model.forward = fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, fp8_group=model.process_group)(
model.forward
)
model = apply_fp8_autowrap(model, self.fp8_recipe_handler)
# torch.compile should be called last and only if the model isn't already compiled.
if self.state.dynamo_plugin.backend != DynamoBackend.NO and not is_compiled_module(model):
if not is_torch_version(">=", "2.0"):
@ -1588,6 +1578,42 @@ class Accelerator:
model = torch.compile(model, **self.state.dynamo_plugin.to_kwargs())
return model
def _prepare_te(self, *args):
        if not is_transformer_engine_available():
raise ImportError(
"`transformer_engine` was not found on your system. Please ensure that `transformer_engine` is installed"
)
model, optimizer = None, None
num_models, num_optimizers = 0, 0
result = [obj for obj in args]
for obj in result:
if isinstance(obj, torch.nn.Module):
model = obj
num_models += 1
elif isinstance(obj, (torch.optim.Optimizer)):
optimizer = obj
num_optimizers += 1
if optimizer is None and model is None:
return result
elif optimizer is None or model is None:
raise ValueError(
"You must pass a model and an optimizer together to `accelerate.prepare()` when using TransformerEngine."
)
elif num_models > 1 or num_optimizers > 1:
raise ValueError(
f"You can't use multiple models ({num_models}) or optimizers {num_optimizers} with TransformerEngine."
)
old_named_params = self._get_named_parameters(model)
with torch.no_grad():
convert_model(model)
new_named_params = self._get_named_parameters(model)
mapping = {p: new_named_params[n] for n, p in old_named_params.items()}
        # Point the optimizer's param groups at the new parameters created by `convert_model`
for param_group in optimizer.param_groups:
param_group["params"] = [mapping[p] for p in param_group["params"]]
return result
def _prepare_deepspeed(self, *args):
import deepspeed
@ -1696,6 +1722,9 @@ class Accelerator:
)
if model is not None:
# If we are using FP8, we need to apply the autowrap now
if getattr(self.fp8_recipe_handler, "backend", None) == "TE":
model = apply_fp8_autowrap(model, self.fp8_recipe_handler)
# if the model is an MOE, set the appropriate MOE layers as leaf Z3 modules
deepspeed_plugin.set_moe_leaf_modules(model)
# deal with config keys that use `auto` value and rely on model's hidden_size


@ -20,10 +20,13 @@ from ...utils import (
ComputeEnvironment,
DistributedType,
is_deepspeed_available,
is_fp8_available,
is_mlu_available,
is_mps_available,
is_msamp_available,
is_musa_available,
is_npu_available,
is_transformer_engine_available,
is_transformers_available,
is_xpu_available,
)
@ -42,6 +45,7 @@ from .config_utils import (
_ask_options,
_convert_distributed_mode,
_convert_dynamo_backend,
_convert_fp8_backend,
_convert_mixed_precision,
_convert_yes_no_to_bool,
)
@ -616,6 +620,7 @@ def get_cluster_input():
error_message="Please enter yes or no.",
)
fp8_config = None
if distributed_type == DistributedType.XLA:
mixed_precision = "no"
main_training_function = _ask_field(
@ -697,10 +702,86 @@ def get_cluster_input():
mixed_precision = None
else:
mixed_precision = _ask_options(
"Do you wish to use FP16 or BF16 (mixed precision)?",
"Do you wish to use mixed precision?",
["no", "fp16", "bf16", "fp8"],
_convert_mixed_precision,
)
if mixed_precision == "fp8":
if not is_fp8_available():
raise ValueError("FP8 (either Transformer Engine or MSAMP) is not installed on this machine.")
fp8_config = {}
fp8_config["backend"] = _ask_options(
"Which FP8 backend do you want to use?",
["te", "msamp"],
_convert_fp8_backend,
)
if fp8_config["backend"] == "TE":
if not is_transformer_engine_available():
raise ValueError("TransformersEngine was selected, but it is not installed on this machine.")
fp8_config["use_autocast_during_eval"] = _ask_field(
"Do you want to use FP8 autocast during eval mode? Generally better metrics are found when this is disabled [yes/NO]: ",
_convert_yes_no_to_bool,
default=False,
)
fp8_config["margin"] = _ask_field(
"What margin should be used for gradient scaling? [0]: ",
int,
default=0,
)
fp8_config["interval"] = _ask_field(
"What interval should be used for for how often the scaling factor is recomputed? [1]: ",
int,
default=1,
)
fp8_config["fp8_format"] = _ask_options(
"Which weight format should be used?",
["E4M3", "HYBRID"],
lambda x: "E4M3" if x == 0 else "HYBRID",
default=0,
)
fp8_config["amax_history_length"] = _ask_field(
"What length of history should be used for the amax scaling factor computation? [1024]: ",
int,
default=1024,
)
fp8_config["amax_compute_algorithm"] = _ask_options(
"Which algorithm should be used for the amax scaling factor computation?",
["max", "most_recent"],
lambda x: "max" if x == 0 else "most_recent",
default=0,
)
fp8_config["override_linear_precision"] = _ask_field(
"Do you want to to execute `fprop`, `dgrad`, and `wgrad` GEMMS in higher precision? [yes/NO]: ",
_convert_yes_no_to_bool,
default=False,
)
if fp8_config["override_linear_precision"]:
fprop = _ask_field(
"Should `fprop` be executed in higher precision? [yes/NO]: ",
_convert_yes_no_to_bool,
default=False,
)
dgrad = _ask_field(
"Should `dgrad` be executed in higher precision? [yes/NO]: ",
_convert_yes_no_to_bool,
default=False,
)
wgrad = _ask_field(
"Should `wgrad` be executed in higher precision? [yes/NO]: ",
_convert_yes_no_to_bool,
default=False,
)
fp8_config["override_linear_precision"] = (fprop, dgrad, wgrad)
elif fp8_config["backend"] == "MSAMP":
if not is_msamp_available():
raise ValueError("MSAMP was selected, but it is not installed on this machine.")
fp8_config["optimization_level"] = _ask_options(
"Which optimization level should be used?",
["O1", "O2"],
lambda x: "O1" if x == 0 else "O2",
default=1,
)
if use_dynamo and mixed_precision == "no" and not use_cpu:
print(
@ -724,6 +805,7 @@ def get_cluster_input():
main_process_ip=main_process_ip,
main_process_port=main_process_port,
main_training_function=main_training_function,
fp8_config=fp8_config,
deepspeed_config=deepspeed_config,
fsdp_config=fsdp_config,
megatron_lm_config=megatron_lm_config,
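For reference, accepting all defaults in the TE branch of the questionnaire above produces an `fp8_config` shaped roughly like this (an illustrative sketch, not canonical output; `backend` holds the string enum `FP8BackendType.TE`):

```py
# Illustrative shape of the dict assembled by get_cluster_input() for the TE backend,
# with every prompt answered by its default:
fp8_config = {
    "backend": "TE",                      # FP8BackendType.TE serializes to "TE"
    "use_autocast_during_eval": False,
    "margin": 0,
    "interval": 1,
    "fp8_format": "E4M3",
    "amax_history_length": 1024,
    "amax_compute_algorithm": "max",
    "override_linear_precision": False,
}
```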

View File

@ -83,11 +83,19 @@ class BaseConfig:
def to_dict(self):
result = self.__dict__
# For serialization, it's best to convert Enums to strings (or their underlying value type).
for key, value in result.items():
def _convert_enums(value):
if isinstance(value, Enum):
result[key] = value.value
if isinstance(value, dict) and not bool(value):
result[key] = None
return value.value
if isinstance(value, dict):
if not bool(value):
return None
for key1, value1 in value.items():
value[key1] = _convert_enums(value1)
return value
for key, value in result.items():
result[key] = _convert_enums(value)
result = {k: v for k, v in result.items() if v is not None}
return result
@ -184,6 +192,8 @@ class ClusterConfig(BaseConfig):
main_training_function: str = "main"
enable_cpu_affinity: bool = False
# args for FP8 training
fp8_config: dict = None
# args for deepspeed_plugin
deepspeed_config: dict = None
# args for fsdp
@ -221,6 +231,8 @@ class ClusterConfig(BaseConfig):
self.ipex_config = {}
if self.mpirun_config is None:
self.mpirun_config = {}
if self.fp8_config is None:
self.fp8_config = {}
return super().__post_init__()
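The net effect of the recursive helper: enums nested one level down (such as `fp8_config["backend"]`) now serialize to plain strings, and empty sub-dicts still collapse to `None` and get dropped. A minimal standalone sketch of that behavior (assuming plain values fall through unchanged):

```py
from enum import Enum


class Backend(str, Enum):
    TE = "TE"


def _convert_enums(value):
    # Sketch of the helper above; plain values pass through unchanged.
    if isinstance(value, Enum):
        return value.value
    if isinstance(value, dict):
        if not value:
            return None
        return {k: _convert_enums(v) for k, v in value.items()}
    return value


result = {"fp8_config": {"backend": Backend.TE, "margin": 0}, "ipex_config": {}}
result = {k: _convert_enums(v) for k, v in result.items()}
result = {k: v for k, v in result.items() if v is not None}
assert result == {"fp8_config": {"backend": "TE", "margin": 0}}
```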

View File

@ -20,6 +20,7 @@ from ...utils.dataclasses import (
ComputeEnvironment,
DistributedType,
DynamoBackend,
FP8BackendType,
PrecisionType,
SageMakerDistributedType,
)
@ -90,6 +91,11 @@ def _convert_sagemaker_distributed_mode(value):
return SageMakerDistributedType(["NO", "DATA_PARALLEL", "MODEL_PARALLEL"][value])
def _convert_fp8_backend(value):
value = int(value)
return FP8BackendType(["TE", "MSAMP"][value])
def _convert_yes_no_to_bool(value):
return {"yes": True, "no": False}[value.lower()]

View File

@ -53,6 +53,7 @@ from accelerate.utils import (
prepare_sagemager_args_inputs,
prepare_simple_launcher_cmd_env,
prepare_tpu,
str_to_bool,
)
from accelerate.utils.constants import DEEPSPEED_MULTINODE_LAUNCHERS, TORCH_DYNAMO_MODES
@ -74,11 +75,14 @@ options_to_group = {
"use_deepspeed": "DeepSpeed Arguments",
"use_fsdp": "FSDP Arguments",
"use_megatron_lm": "Megatron-LM Arguments",
"fp8_backend": "FP8 Arguments",
}
def clean_option(option):
"Finds all cases of - after the first two characters and changes them to _"
if "fp8_backend" in option:
option = "--fp8_backend"
if option.startswith("--"):
return option[2:].replace("-", "_")
@ -214,7 +218,6 @@ def launch_command_parser(subparsers=None):
action="store_true",
help="Whether or not CPU affinity and balancing should be enabled. Currently only supported on NVIDIA hardware.",
)
# Dynamo arguments
resource_args.add_argument(
"--dynamo_backend",
@ -642,6 +645,68 @@ def launch_command_parser(subparsers=None):
"(useful only when `use_megatron_lm` flag is passed).",
)
# FP8 arguments
fp8_args = parser.add_argument_group(
"FP8 Arguments", "Arguments related to FP8 training (requires `--mixed_precision=fp8`)"
)
fp8_args.add_argument(
"--fp8_backend",
type=str,
choices=["te", "msamp"],
help="Choose a backend to train with FP8 (te: TransformerEngine, msamp: MS-AMP)",
)
fp8_args.add_argument(
"--fp8_use_autocast_during_eval",
default=False,
action="store_true",
help="Whether to use FP8 autocast during eval mode (useful only when `--fp8_backend=te` is passed). Generally better metrics are found when this is not passed.",
)
fp8_args.add_argument(
"--fp8_margin",
type=int,
default=0,
help="The margin to use for the gradient scaling (useful only when `--fp8_backend=te` is passed).",
)
fp8_args.add_argument(
"--fp8_interval",
type=int,
default=1,
help="The interval to use for how often the scaling factor is recomputed (useful only when `--fp8_backend=te` is passed).",
)
fp8_args.add_argument(
"--fp8_format",
type=str,
default="E4M3",
choices=["E4M3", "HYBRID"],
help="The format to use for the FP8 recipe (useful only when `--fp8_backend=te` is passed).",
)
fp8_args.add_argument(
"--fp8_amax_history_len",
type=int,
default=1024,
help="The length of the history to use for the scaling factor computation (useful only when `--fp8_backend=te` is passed).",
)
fp8_args.add_argument(
"--fp8_amax_compute_algo",
type=str,
default="most_recent",
choices=["max", "most_recent"],
help="The algorithm to use for the scaling factor computation. (useful only when `--fp8_backend=te` is passed).",
)
fp8_args.add_argument(
"--fp8_override_linear_precision",
type=lambda x: tuple(map(str_to_bool, x.split(","))),
default=(False, False, False),
help="Whether or not to execute `fprop`, `dgrad`, and `wgrad` GEMMS in higher precision. Should be passed in a comma-seperated string of booleans (useful only when `--fp8_backend=te` is passed).",
)
fp8_args.add_argument(
"--fp8_opt_level",
type=str,
default="O2",
choices=["O1", "O2"],
help="What level of 8-bit collective communication should be used with MS-AMP (useful only when `--fp8_backend=msamp` is passed).",
)
# AWS arguments
aws_args = parser.add_argument_group("AWS Arguments", "Arguments related to AWS.")
aws_args.add_argument(

View File

@ -41,6 +41,7 @@ from .testing import (
require_torch_min_version,
require_torchvision,
require_tpu,
require_transformer_engine,
require_xpu,
skip,
slow,

View File

@ -53,6 +53,7 @@ from ..utils import (
is_torch_version,
is_torch_xla_available,
is_torchvision_available,
is_transformer_engine_available,
is_transformers_available,
is_triton_available,
is_wandb_available,
@ -404,6 +405,14 @@ def require_import_timer(test_case):
return unittest.skipUnless(is_import_timer_available(), "test requires tuna interpreter")(test_case)
def require_transformer_engine(test_case):
"""
Decorator marking a test that requires transformer engine installed. These tests are skipped when transformer
engine isn't installed.
"""
return unittest.skipUnless(is_transformer_engine_available(), "test requires transformers engine")(test_case)
_atleast_one_tracker_available = (
any([is_wandb_available(), is_tensorboard_available()]) and not is_comet_ml_available()
)
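The new decorator composes like the other `require_*` helpers; an illustrative (hypothetical) test:

```py
from accelerate.test_utils import require_transformer_engine


@require_transformer_engine
def test_te_linear_exists():
    # Runs only when transformer_engine is importable; skipped otherwise.
    import transformer_engine.pytorch as te

    assert hasattr(te, "Linear")
```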

View File

@ -149,6 +149,7 @@ from .offload import (
)
from .operations import (
CannotPadNestedTensorWarning,
GatheredParameters,
broadcast,
broadcast_object_list,
concatenate,
@ -250,4 +251,9 @@ from .other import (
from .random import set_seed, synchronize_rng_state, synchronize_rng_states
from .torch_xla import install_xla
from .tqdm import tqdm
from .transformer_engine import convert_model, has_transformer_engine_layers
from .transformer_engine import (
apply_fp8_autowrap,
contextual_fp8_autocast,
convert_model,
has_transformer_engine_layers,
)

View File

@ -30,8 +30,15 @@ from typing import Any, Callable, Dict, Iterable, List, Literal, Optional, Tuple
import torch
from .constants import FSDP_AUTO_WRAP_POLICY, FSDP_BACKWARD_PREFETCH, FSDP_SHARDING_STRATEGY
from .environment import str_to_bool
from .imports import is_cuda_available, is_mlu_available, is_npu_available, is_xpu_available
from .environment import parse_flag_from_env, str_to_bool
from .imports import (
is_cuda_available,
is_mlu_available,
is_msamp_available,
is_npu_available,
is_transformer_engine_available,
is_xpu_available,
)
from .versions import compare_versions
@ -297,8 +304,11 @@ class FP8RecipeKwargs(KwargsHandler):
```
Args:
backend (`str`, *optional*, defaults to "msamp"):
Which FP8 engine to use. Must be one of `"msamp"` (MS-AMP) or `"te"` (TransformerEngine).
backend (`str`, *optional*):
Which FP8 engine to use. Must be one of `"msamp"` (MS-AMP) or `"te"` (TransformerEngine). If not passed,
will use whichever is available in the environment, prioritizing MS-AMP.
use_autocast_during_eval (`bool`, *optional*, defaults to `False`):
Whether to use FP8 autocast during eval mode. Generally better metrics are found when this is `False`.
margin (`int`, *optional*, defaults to 0):
The margin to use for the gradient scaling.
interval (`int`, *optional*, defaults to 1):
@ -323,28 +333,60 @@ class FP8RecipeKwargs(KwargsHandler):
available currently).
"""
backend: Backend = "MSAMP"
opt_level: OptLevel = "O2"
margin: int = 0
interval: int = 1
fp8_format: FP8Format = "E4M3"
amax_history_len: int = 1
amax_compute_algo: AmaxComputeAlgorithm = "most_recent"
override_linear_precision: Tuple[bool, bool, bool] = (False, False, False)
backend: Backend = None
use_autocast_during_eval: bool = None
opt_level: OptLevel = None
margin: int = None
interval: int = None
fp8_format: FP8Format = None
amax_history_len: int = None
amax_compute_algo: AmaxComputeAlgorithm = None
override_linear_precision: Tuple[bool, bool, bool] = None
def __post_init__(self):
if self.backend.upper() not in get_args(Backend):
raise ValueError("`backend` must be 'MSAMP' or 'TE' (TransformerEngine).")
env_prefix = "ACCELERATE_FP8_"
default_backend = "msamp" if is_msamp_available() else "te"
if self.backend is None:
self.backend = os.environ.get(env_prefix + "BACKEND", default_backend)
self.backend = self.backend.upper()
if self.backend not in get_args(Backend):
raise ValueError("`backend` must be 'MSAMP' or 'TE' (TransformerEngine).")
# Check TE args
if self.backend == "TE":
if not is_transformer_engine_available():
raise ValueError(
"TransformerEngine is not available. Please either install it, or use the 'MSAMP' backend (if installed)."
)
if self.use_autocast_during_eval is None:
self.use_autocast_during_eval = parse_flag_from_env(env_prefix + "USE_AUTOCAST_DURING_EVAL")
if self.margin is None:
self.margin = int(os.environ.get(env_prefix + "MARGIN", 0))
if self.interval is None:
self.interval = int(os.environ.get(env_prefix + "INTERVAL", 1))
if self.fp8_format is None:
self.fp8_format = os.environ.get(env_prefix + "FORMAT", "E4M3")
self.fp8_format = self.fp8_format.upper()
if self.fp8_format not in get_args(FP8Format):
raise ValueError(f"`fp8_format` must be one of {' or '.join(get_args(FP8Format))}.")
if self.amax_compute_algo is None:
self.amax_compute_algo = os.environ.get(env_prefix + "AMAX_COMPUTE_ALGO", "most_recent")
self.amax_compute_algo = self.amax_compute_algo.lower()
if self.amax_compute_algo not in get_args(AmaxComputeAlgorithm):
raise ValueError(f"`amax_compute_algo` must be one of {' or '.join(get_args(AmaxComputeAlgorithm))}")
if self.amax_history_len is None:
self.amax_history_len = int(os.environ.get(env_prefix + "AMAX_HISTORY_LEN", 1024))
if self.override_linear_precision is None:
fprop = parse_flag_from_env(env_prefix + "OVERRIDE_FPROP")
dgrad = parse_flag_from_env(env_prefix + "OVERRIDE_DGRAD")
wgrad = parse_flag_from_env(env_prefix + "OVERRIDE_WGRAD")
self.override_linear_precision = (fprop, dgrad, wgrad)
elif self.backend == "MSAMP":
if not is_msamp_available():
raise ValueError(
"MS-AMP is not available. Please either install it, or use the 'TE' backend (if installed)."
)
if self.opt_level is None:
self.opt_level = os.environ.get(env_prefix + "OPT_LEVEL", "O2")
if self.opt_level not in get_args(OptLevel):
raise ValueError(f"`optimization_level` must be one of {' or '.join(get_args(OptLevel))}")
@ -534,6 +576,21 @@ class SageMakerDistributedType(str, enum.Enum):
MODEL_PARALLEL = "MODEL_PARALLEL"
class FP8BackendType(str, enum.Enum):
"""
Represents the backend used for FP8.
Values:
- **TE** -- using TransformerEngine.
- **MSAMP** -- using MS-AMP.
"""
# Subclassing str as well as Enum allows the `FP8BackendType` to be JSON-serializable out of the box.
TE = "TE"
MSAMP = "MSAMP"
class ComputeEnvironment(str, enum.Enum):
"""
Represents a type of the compute environment.
@ -625,7 +682,7 @@ class LoggerType(BaseEnum):
DVCLIVE = "dvclive"
class PrecisionType(BaseEnum):
class PrecisionType(str, BaseEnum):
"""Represents a type of precision used on floating point values
Values:
@ -1077,12 +1134,13 @@ class DeepSpeedPlugin:
ds_config = self.deepspeed_config
kwargs = {
"fp16.enabled": mixed_precision == "fp16",
"bf16.enabled": mixed_precision == "bf16",
# When training in fp8, we still rely on bf16 autocast for the core mixed precision
"bf16.enabled": mixed_precision in ("bf16", "fp8"),
}
if mixed_precision == "fp16":
if "fp16" not in ds_config:
ds_config["fp16"] = {"enabled": True, "auto_cast": True}
elif mixed_precision == "bf16":
elif mixed_precision in ("bf16", "fp8"):
if "bf16" not in ds_config:
ds_config["bf16"] = {"enabled": True}
@ -1497,7 +1555,12 @@ class FullyShardedDataParallelPlugin:
def set_mixed_precision(self, mixed_precision, buffer_autocast=False, override=False):
"Sets the mixed precision policy for FSDP"
mixed_precision_mapping = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp32": torch.float32}
mixed_precision_mapping = {
"fp8": torch.bfloat16,
"fp16": torch.float16,
"bf16": torch.bfloat16,
"fp32": torch.float32,
}
dtype = mixed_precision
if isinstance(mixed_precision, str):
dtype = mixed_precision_mapping.get(mixed_precision, None)

View File

@ -102,7 +102,7 @@ def is_schedulefree_available():
def is_transformer_engine_available():
return _is_package_available("transformer_engine")
return _is_package_available("transformer_engine", "transformer-engine")
def is_lomo_available():

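The second argument covers the case where the import name (`transformer_engine`) differs from the installed distribution name (`transformer-engine`). A sketch of the idea behind `_is_package_available`, not its exact implementation:

```py
import importlib.metadata
import importlib.util


def _is_package_available_sketch(pkg_name, dist_name=None):
    # The module must be importable *and* a distribution must be installed
    # under either the module name or the explicit distribution name.
    if importlib.util.find_spec(pkg_name) is None:
        return False
    try:
        importlib.metadata.metadata(dist_name if dist_name is not None else pkg_name)
        return True
    except importlib.metadata.PackageNotFoundError:
        return False
```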
View File

@ -27,6 +27,7 @@ from ..commands.config.config_args import SageMakerConfig
from ..utils import (
DynamoBackend,
PrecisionType,
is_fp8_available,
is_ipex_available,
is_mlu_available,
is_musa_available,
@ -74,6 +75,19 @@ def _get_mpirun_args():
return mpi_app, "-f", "-n", "-ppn", ""
def setup_fp8_env(args: argparse.Namespace, current_env: Dict[str, str]):
"""
Set up the FP8 environment variables.
"""
prefix = "ACCELERATE_"
for arg in vars(args):
if arg.startswith("fp8_"):
value = getattr(args, arg)
if value is not None:
current_env[f"{prefix}{arg.upper()}"] = str(getattr(args, arg))
return current_env
def prepare_simple_launcher_cmd_env(args: argparse.Namespace) -> Tuple[List[str], Dict[str, str]]:
"""
Prepares and returns the command list and an environment with the correct simple launcher environment variables.
@ -140,6 +154,12 @@ def prepare_simple_launcher_cmd_env(args: argparse.Namespace) -> Tuple[List[str]
)
current_env["ACCELERATE_MIXED_PRECISION"] = str(mixed_precision)
if args.mixed_precision.lower() == "fp8":
if not is_fp8_available():
raise RuntimeError(
"FP8 is not available on this machine. Please ensure that either Transformer Engine or MSAMP is installed."
)
current_env = setup_fp8_env(args, current_env)
try:
dynamo_backend = DynamoBackend(args.dynamo_backend.upper())
@ -225,6 +245,12 @@ def prepare_multi_gpu_env(args: argparse.Namespace) -> Dict[str, str]:
raise ValueError(f"Unknown mixed_precision mode: {mixed_precision}. Choose between {PrecisionType.list()}.")
current_env["ACCELERATE_MIXED_PRECISION"] = str(mixed_precision)
if args.mixed_precision.lower() == "fp8":
if not is_fp8_available():
raise RuntimeError(
"FP8 is not available on this machine. Please ensure that either Transformer Engine or MSAMP is installed."
)
current_env = setup_fp8_env(args, current_env)
try:
dynamo_backend = DynamoBackend(args.dynamo_backend.upper())
@ -390,6 +416,12 @@ def prepare_deepspeed_cmd_env(args: argparse.Namespace) -> Tuple[List[str], Dict
current_env["PYTHONPATH"] = env_var_path_add("PYTHONPATH", os.path.abspath("."))
current_env["ACCELERATE_MIXED_PRECISION"] = str(mixed_precision)
if args.mixed_precision.lower() == "fp8":
if not is_fp8_available():
raise RuntimeError(
"FP8 is not available on this machine. Please ensure that either Transformer Engine or MSAMP is installed."
)
current_env = setup_fp8_env(args, current_env)
current_env["ACCELERATE_CONFIG_DS_FIELDS"] = str(args.deepspeed_fields_from_accelerate_config).lower()
current_env["ACCELERATE_USE_DEEPSPEED"] = "true"
if args.zero_stage is not None:
@ -528,6 +560,12 @@ def prepare_sagemager_args_inputs(
"ACCELERATE_DYNAMO_USE_DYNAMIC": str(args.dynamo_use_dynamic),
"ACCELERATE_SAGEMAKER_DISTRIBUTED_TYPE": sagemaker_config.distributed_type.value,
}
if args.mixed_precision.lower() == "fp8":
if not is_fp8_available():
raise RuntimeError(
"FP8 is not available on this machine. Please ensure that either Transformer Engine or MSAMP is installed."
)
environment = setup_fp8_env(args, environment)
# configure distribution set up
distribution = None
if sagemaker_config.distributed_type == SageMakerDistributedType.DATA_PARALLEL:

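Concretely, `setup_fp8_env` upper-cases every `fp8_*` launcher flag into the `ACCELERATE_` namespace, producing exactly the `ACCELERATE_FP8_*` variables that `FP8RecipeKwargs.__post_init__` reads back inside the training process. A minimal sketch with a hypothetical args namespace:

```py
import argparse

# Hypothetical namespace mirroring the launcher flags above:
args = argparse.Namespace(fp8_backend="te", fp8_margin=0, num_processes=2)

current_env = {}
for arg in vars(args):
    if arg.startswith("fp8_"):
        value = getattr(args, arg)
        if value is not None:
            current_env[f"ACCELERATE_{arg.upper()}"] = str(value)

assert current_env == {"ACCELERATE_FP8_BACKEND": "te", "ACCELERATE_FP8_MARGIN": "0"}
```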
View File

@ -17,12 +17,13 @@ A set of basic tensor ops compatible with tpu, gpu, and multigpu
import pickle
import warnings
from contextlib import contextmanager, nullcontext
from functools import update_wrapper, wraps
from typing import Any, Mapping
import torch
from ..state import PartialState
from ..state import AcceleratorState, PartialState
from .constants import TORCH_DISTRIBUTED_OPERATION_TYPES
from .dataclasses import DistributedType, TensorInformation
from .imports import (
@ -843,3 +844,25 @@ def find_device(data):
return device
elif isinstance(data, torch.Tensor):
return data.device
@contextmanager
def GatheredParameters(params, modifier_rank=None, fwd_module=None, enabled=True):
"""
Wrapper around `deepspeed.runtime.zero.GatheredParameters`, but if ZeRO-3 is not enabled, will be a no-op context
manager.
"""
# We need to use the `AcceleratorState` here since it has access to the deepspeed plugin
if AcceleratorState().distributed_type != DistributedType.DEEPSPEED or (
AcceleratorState().deepspeed_plugin is not None
and not AcceleratorState().deepspeed_plugin.is_zero3_init_enabled()
):
gather_param_context = nullcontext()
else:
import deepspeed
gather_param_context = deepspeed.zero.GatheredParameters(
params, modifier_rank=modifier_rank, fwd_module=fwd_module, enabled=enabled
)
with gather_param_context:
yield
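Usage mirrors `deepspeed.zero.GatheredParameters`, but callers no longer need to branch on the distributed backend: outside DeepSpeed ZeRO-3 the context is simply a no-op. A minimal sketch (assumes an `Accelerator`/`AcceleratorState` has been set up):

```py
import torch.nn as nn

from accelerate import Accelerator
from accelerate.utils import GatheredParameters

accelerator = Accelerator()
layer = nn.Linear(16, 16)
# Under ZeRO-3 the sharded weight/bias are materialized for the duration of the
# context (edits on `modifier_rank` broadcast on exit); elsewhere, a nullcontext.
with GatheredParameters([layer.weight, layer.bias], modifier_rank=0):
    print(layer.weight.shape)
```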

View File

@ -12,9 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from types import MethodType
import torch.nn as nn
from .imports import is_fp8_available
from .operations import GatheredParameters
if is_fp8_available():
@ -29,22 +32,28 @@ def convert_model(model, to_transformer_engine=True, _convert_linear=True, _conv
raise ImportError("Using `convert_model` requires transformer_engine to be installed.")
for name, module in model.named_children():
if isinstance(module, nn.Linear) and to_transformer_engine and _convert_linear:
# Return early if the linear layer weights are not multiples of 16
if any(p % 16 != 0 for p in module.weight.shape):
return
has_bias = module.bias is not None
te_module = te.Linear(
module.in_features, module.out_features, bias=has_bias, params_dtype=module.weight.dtype
)
te_module.weight.copy_(module.weight)
params_to_gather = [module.weight]
if has_bias:
te_module.bias.copy_(module.bias)
params_to_gather.append(module.bias)
setattr(model, name, te_module)
with GatheredParameters(params_to_gather, modifier_rank=0):
if any(p % 16 != 0 for p in module.weight.shape):
return
te_module = te.Linear(
module.in_features, module.out_features, bias=has_bias, params_dtype=module.weight.dtype
)
te_module.weight.copy_(module.weight)
if has_bias:
te_module.bias.copy_(module.bias)
setattr(model, name, te_module)
# Note: @xrsrke (Phuc) found that te.LayerNorm doesn't have any real memory savings or speedups over nn.LayerNorm
elif isinstance(module, nn.LayerNorm) and to_transformer_engine and _convert_ln:
te_module = te.LayerNorm(module.normalized_shape[0], eps=module.eps, params_dtype=module.weight.dtype)
te_module.weight.copy_(module.weight)
te_module.bias.copy_(module.bias)
with GatheredParameters([module.weight, module.bias], modifier_rank=0):
te_module = te.LayerNorm(module.normalized_shape[0], eps=module.eps, params_dtype=module.weight.dtype)
te_module.weight.copy_(module.weight)
te_module.bias.copy_(module.bias)
setattr(model, name, te_module)
elif isinstance(module, te.Linear) and not to_transformer_engine and _convert_linear:
@ -82,3 +91,43 @@ def has_transformer_engine_layers(model):
if isinstance(m, (te.LayerNorm, te.Linear, te.TransformerLayer)):
return True
return False
def contextual_fp8_autocast(model_forward, fp8_recipe, use_during_eval=False):
"""
Wrapper for a model's forward method to apply FP8 autocast. It is context aware: by default it disables FP8
autocast during eval mode, which generally yields more accurate metrics.
"""
from transformer_engine.pytorch import fp8_autocast
def forward(self, *args, **kwargs):
enabled = use_during_eval or self.training
with fp8_autocast(enabled=enabled, fp8_recipe=fp8_recipe):
return model_forward(*args, **kwargs)
# To act like a decorator so that it can be popped when doing `extract_model_from_parallel`
forward.__wrapped__ = model_forward
return forward
def apply_fp8_autowrap(model, fp8_recipe_handler):
"""
Applies the FP8 context manager to the model's forward method.
"""
# Import here to keep base imports fast
import transformer_engine.common.recipe as te_recipe
kwargs = fp8_recipe_handler.to_kwargs() if fp8_recipe_handler is not None else {}
if "fp8_format" in kwargs:
kwargs["fp8_format"] = getattr(te_recipe.Format, kwargs["fp8_format"])
use_during_eval = kwargs.pop("use_autocast_during_eval", False)
fp8_recipe = te_recipe.DelayedScaling(**kwargs)
new_forward = contextual_fp8_autocast(model.forward, fp8_recipe, use_during_eval)
if hasattr(model.forward, "__func__"):
model.forward = MethodType(new_forward, model)
else:
model.forward = new_forward
return model
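Putting the pieces together, the wrap leaves `model.forward` guarded by an `fp8_autocast` whose enablement follows `model.training`. A hedged sketch (requires transformer_engine and an FP8-capable CUDA GPU, and assumes MS-AMP is not also installed; the toy model is illustrative, and in normal use `Accelerator.prepare` drives these utilities for you):

```py
import torch
import torch.nn as nn

from accelerate.utils import FP8RecipeKwargs, apply_fp8_autowrap, convert_model

model = nn.Sequential(nn.Linear(64, 64), nn.LayerNorm(64)).cuda()
convert_model(model)  # swaps in te.Linear / te.LayerNorm in place
model = apply_fp8_autowrap(model, FP8RecipeKwargs(backend="TE"))

x = torch.randn(16, 64, device="cuda")
model.train()
_ = model(x)  # forward runs inside fp8_autocast
model.eval()
_ = model(x)  # autocast disabled, since use_autocast_during_eval defaults to False
```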

View File

@ -27,9 +27,16 @@ from torch.utils.data import DataLoader, TensorDataset
from accelerate import DistributedType, infer_auto_device_map, init_empty_weights, load_checkpoint_and_dispatch
from accelerate.accelerator import Accelerator
from accelerate.state import GradientState, PartialState
from accelerate.test_utils import require_bnb, require_multi_gpu, require_non_cpu, slow, torch_device
from accelerate.test_utils import (
require_bnb,
require_multi_gpu,
require_non_cpu,
require_transformer_engine,
slow,
torch_device,
)
from accelerate.test_utils.testing import AccelerateTestCase, require_cuda, require_non_torch_xla
from accelerate.utils import patch_environment
from accelerate.utils import FP8RecipeKwargs, patch_environment
from accelerate.utils.modeling import get_state_dict_from_offload, load_checkpoint_in_model
@ -561,6 +568,22 @@ class AcceleratorTester(AccelerateTestCase):
accelerator = Accelerator(cpu=True)
_ = accelerator.prepare(sgd)
@require_transformer_engine
def test_can_unwrap_model_te(self):
model, optimizer, *_ = create_components()
fp8_recipe = FP8RecipeKwargs(backend="TE")
accelerator = Accelerator(mixed_precision="fp8", kwargs_handlers=[fp8_recipe])
inputs = torch.randn(10, 2).to(torch_device)
model, optimizer = accelerator.prepare(model, optimizer)
model(inputs) # sanity check that this works
model = accelerator.unwrap_model(model, keep_fp32_wrapper=False)
model(inputs) # check that this still works
# check that pickle roundtrip works
model_loaded = pickle.loads(pickle.dumps(model))
model_loaded(inputs)
@require_non_cpu
def test_can_unwrap_model_fp16(self):
# test for a regression introduced in #872

View File

@ -73,7 +73,7 @@ class AccelerateLauncherTester(unittest.TestCase):
execute_subprocess_async(cmd, env=os.environ.copy())
def test_config_compatibility(self):
invalid_configs = ["invalid", "mpi", "sagemaker"]
invalid_configs = ["fp8", "invalid", "mpi", "sagemaker"]
for config in sorted(self.test_config_path.glob("**/*.yaml")):
if any(invalid_config in str(config) for invalid_config in invalid_configs):
continue

View File

@ -0,0 +1,26 @@
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: MULTI_GPU
downcast_bf16: 'no'
enable_cpu_affinity: false
fp8_config:
amax_compute_algorithm: max
amax_history_length: 1024
backend: TE
fp8_format: E4M3
interval: 1
margin: 0
override_linear_precision: false
use_autocast_during_eval: false
gpu_ids: all
machine_rank: 0
main_training_function: main
mixed_precision: fp8
num_machines: 1
num_processes: 2
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false