mirror of
https://github.com/huggingface/accelerate.git
synced 2025-10-20 10:03:46 +08:00
* Setup 2023 tooling for quality * Result of styling * Simplify inits and remove isort and flake8 from doc * Puts back isort skip flag
233 lines
9.6 KiB
Python
233 lines
9.6 KiB
Python
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
import argparse
|
|
import os
|
|
|
|
# New Code #
|
|
import evaluate
|
|
import torch
|
|
from datasets import load_dataset
|
|
from torch.optim import AdamW
|
|
from torch.utils.data import DataLoader
|
|
from transformers import AutoModelForSequenceClassification, AutoTokenizer, get_linear_schedule_with_warmup, set_seed
|
|
|
|
from accelerate import Accelerator, DistributedType
|
|
from accelerate.utils import find_executable_batch_size
|
|
|
|
|
|
########################################################################
|
|
# This is a fully working simple example to use Accelerate,
|
|
# specifically showcasing how to combine both the gradient accumulation
|
|
# and automatic batch size finder utilities of Accelerate to perform
|
|
# automatic gradient accumulation
|
|
#
|
|
# This example trains a Bert base model on GLUE MRPC
|
|
# in any of the following settings (with the same script):
|
|
# - single CPU or single GPU
|
|
# - multi GPUS (using PyTorch distributed mode)
|
|
# - (multi) TPUs
|
|
# - fp16 (mixed-precision) or fp32 (normal precision)
|
|
#
|
|
# New additions from the base script can be found quickly by
|
|
# looking for the # New Code # tags
|
|
#
|
|
# To run it in each of these various modes, follow the instructions
|
|
# in the readme for examples:
|
|
# https://github.com/huggingface/accelerate/tree/main/examples
|
|
#
|
|
########################################################################
|
|
|
|
# Fixed batch size for the validation DataLoader built in `get_dataloaders`;
# it does not follow the training batch size.
EVAL_BATCH_SIZE = 32
|
|
|
|
|
|
def get_dataloaders(accelerator: Accelerator, batch_size: int = 16):
    """
    Builds the train and validation `DataLoader`s for the GLUE MRPC dataset,
    tokenized with the "bert-base-cased" tokenizer.

    Args:
        accelerator (`Accelerator`):
            The `Accelerator` used to order the tokenization across processes and to
            decide how batches are padded.
        batch_size (`int`, *optional*):
            The batch size of the training DataLoader. The validation DataLoader always
            uses `EVAL_BATCH_SIZE`.

    Returns:
        A `(train_dataloader, eval_dataloader)` tuple.
    """
    tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
    raw_splits = load_dataset("glue", "mrpc")

    def encode_pairs(batch):
        # max_length=None => fall back to the model max length (which is the default anyway)
        return tokenizer(batch["sentence1"], batch["sentence2"], truncation=True, max_length=None)

    # Tokenize every split, letting the main process go first.
    with accelerator.main_process_first():
        encoded_splits = raw_splits.map(
            encode_pairs,
            batched=True,
            remove_columns=["idx", "sentence1", "sentence2"],
        )

    # The models of the transformers library expect the label column to be called 'labels'.
    encoded_splits = encoded_splits.rename_column("label", "labels")

    def pad_batch(features):
        # On TPU it's best to pad everything to the same length or training will be very slow.
        if accelerator.distributed_type == DistributedType.TPU:
            padding_options = {"padding": "max_length", "max_length": 128}
        else:
            padding_options = {"padding": "longest"}
        return tokenizer.pad(features, return_tensors="pt", **padding_options)

    # Training batches are shuffled; validation batches keep the dataset order.
    train_loader = DataLoader(
        encoded_splits["train"],
        shuffle=True,
        collate_fn=pad_batch,
        batch_size=batch_size,
    )
    eval_loader = DataLoader(
        encoded_splits["validation"],
        shuffle=False,
        collate_fn=pad_batch,
        batch_size=EVAL_BATCH_SIZE,
    )
    return train_loader, eval_loader
|
|
|
|
|
|
# For testing only: when the test flag is set, replace the real dataloader factory
# with the mocked one.
if os.getenv("TESTING_MOCKED_DATALOADERS") == "1":
    from accelerate.test_utils.training import mocked_dataloaders

    get_dataloaders = mocked_dataloaders  # noqa: F811
|
|
|
|
|
|
def training_function(config, args):
    """
    Trains and evaluates "bert-base-cased" on GLUE MRPC, combining the automatic batch size
    finder with gradient accumulation.

    Args:
        config (`dict`):
            Hyper-parameters. The keys `lr`, `num_epochs`, `seed` and `batch_size` are read;
            `num_epochs` is overwritten with 2 when `TESTING_MOCKED_DATALOADERS` is "1".
        args:
            Parsed command-line arguments. The attributes `cpu` and `mixed_precision` are read.
    """
    # For testing only
    if os.environ.get("TESTING_MOCKED_DATALOADERS") == "1":
        config["num_epochs"] = 2

    # Initialize accelerator
    accelerator = Accelerator(cpu=args.cpu, mixed_precision=args.mixed_precision)

    # Sample hyper-parameters for learning rate, batch size, seed and a few other HPs
    lr = config["lr"]
    num_epochs = int(config["num_epochs"])
    seed = int(config["seed"])
    observed_batch_size = int(config["batch_size"])

    metric = evaluate.load("glue", "mrpc")

    # New Code #
    # The `find_executable_batch_size` decorator starts from the desired observed batch size.
    # If a CUDA OOM error occurs, it retries the loop with the batch size cut in half each time.
    # The number of gradient accumulation steps is derived from the batch size actually used.
    @find_executable_batch_size(starting_batch_size=observed_batch_size)
    def inner_training_loop(batch_size):
        # The outer accelerator is only mutated through its attributes (never rebound),
        # so the closure can use it directly.
        # Gradient accumulation steps = starting batch size / batch size of this attempt.
        accelerator.gradient_accumulation_steps = observed_batch_size // batch_size

        # Free all of the model references stored in the Accelerator from a previous attempt
        accelerator.free_memory()

        # Reset the seed so that results are reproducible on each attempt
        set_seed(seed)

        # Instantiate the model (built here so that the seed also controls new weights initialization)
        model = AutoModelForSequenceClassification.from_pretrained("bert-base-cased", return_dict=True)

        # We could avoid this line since the accelerator is set with `device_placement=True` (default value).
        # Note that if you are placing tensors on devices manually, this line absolutely needs to be before the
        # optimizer creation otherwise training will not work on TPU (`accelerate` will kindly throw an error to
        # make us aware of that).
        model = model.to(accelerator.device)

        # Instantiate optimizer and dataloaders
        optimizer = AdamW(params=model.parameters(), lr=lr)
        train_dataloader, eval_dataloader = get_dataloaders(accelerator, batch_size)

        # Instantiate scheduler
        total_training_steps = len(train_dataloader) * num_epochs
        lr_scheduler = get_linear_schedule_with_warmup(
            optimizer=optimizer,
            num_warmup_steps=100,
            num_training_steps=total_training_steps,
        )

        # Prepare everything. There is no specific order to remember, the objects just need to be
        # unpacked in the same order they were given to the prepare method.
        model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare(
            model, optimizer, train_dataloader, eval_dataloader, lr_scheduler
        )

        # Now we train the model
        for epoch in range(num_epochs):
            model.train()
            for batch in train_dataloader:
                # And perform gradient accumulation
                with accelerator.accumulate(model):
                    # We could avoid this line since we set the accelerator with `device_placement=True`.
                    batch.to(accelerator.device)
                    loss = model(**batch).loss
                    accelerator.backward(loss)
                    optimizer.step()
                    lr_scheduler.step()
                    optimizer.zero_grad()

            model.eval()
            for batch in eval_dataloader:
                # We could avoid this line since we set the accelerator with `device_placement=True`.
                batch.to(accelerator.device)
                with torch.no_grad():
                    logits = model(**batch).logits
                batch_predictions = logits.argmax(dim=-1)
                gathered_predictions, gathered_references = accelerator.gather_for_metrics(
                    (batch_predictions, batch["labels"])
                )
                metric.add_batch(
                    predictions=gathered_predictions,
                    references=gathered_references,
                )

            epoch_scores = metric.compute()
            # Use accelerator.print to print only on the main process.
            accelerator.print(f"epoch {epoch}:", epoch_scores)

    # New Code #
    # And call it at the end with no arguments
    # Note: You could also refactor this outside of your training loop function
    inner_training_loop()
|
|
|
|
|
|
def main():
    """
    Parses the command-line options and launches `training_function`.

    Supported options:
        --mixed_precision: one of "no", "fp16" or "bf16" (defaults to `None`).
        --cpu: flag; when passed, training runs on the CPU.
    """
    parser = argparse.ArgumentParser(description="Simple example of training script.")
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default=None,
        choices=["no", "fp16", "bf16"],
        # Adjacent string literals are concatenated verbatim, so each fragment
        # must carry its own separating space.
        help="Whether to use mixed precision. Choose "
        "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10 "
        "and an Nvidia Ampere GPU.",
    )
    parser.add_argument("--cpu", action="store_true", help="If passed, will train on the CPU.")
    args = parser.parse_args()
    # New Code #
    # We modify the starting batch size to be an observed batch size of 256, to guarantee an initial CUDA OOM
    config = {"lr": 2e-5, "num_epochs": 3, "seed": 42, "batch_size": 256}
    training_function(config, args)


if __name__ == "__main__":
    main()
|