# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import os
import re

import numpy as np
import PIL
import torch
from timm import create_model
from torch.optim.lr_scheduler import OneCycleLR
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import Compose, RandomResizedCrop, Resize, ToTensor

from accelerate import Accelerator, DataLoaderConfiguration


########################################################################
# This is a fully working, simple example of how to use Accelerate
#
# This example trains a ResNet-50d on the Oxford-IIIT Pet Dataset
# in any of the following settings (with the same script):
# - single CPU or single GPU
# - multiple GPUs (using PyTorch distributed mode)
# - (multi) TPUs
# - fp16 (mixed-precision) or fp32 (normal precision)
#
# To run it in each of these various modes, follow the instructions
# in the readme for examples:
# https://github.com/huggingface/accelerate/tree/main/examples
#
########################################################################

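# A typical launch (illustrative only: the script filename and data path below are
# placeholders, and `accelerate config` is assumed to have been run first):
#
#   accelerate launch cv_example.py --data_dir ./images \
#       --checkpointing_steps epoch --use_stateful_dataloader
#
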
# Function to get the label from the filename
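# e.g. extract_label("images/great_pyrenees_102.jpg") -> "great_pyrenees" (hypothetical
# path; the regex strips the trailing `_<index>.jpg` suffix)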
def extract_label(fname):
    stem = fname.split(os.path.sep)[-1]
    return re.search(r"^(.*)_\d+\.jpg$", stem).groups()[0]


class PetsDataset(Dataset):
    def __init__(self, file_names, image_transform=None, label_to_id=None):
        self.file_names = file_names
        self.image_transform = image_transform
        self.label_to_id = label_to_id

    def __len__(self):
        return len(self.file_names)

    def __getitem__(self, idx):
        fname = self.file_names[idx]
        raw_image = PIL.Image.open(fname)
        image = raw_image.convert("RGB")
        if self.image_transform is not None:
            image = self.image_transform(image)
        label = extract_label(fname)
        if self.label_to_id is not None:
            label = self.label_to_id[label]
        return {"image": image, "label": label}


def training_function(config, args):
    # Initialize accelerator
    dataloader_config = DataLoaderConfiguration(use_stateful_dataloader=args.use_stateful_dataloader)
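    # `use_stateful_dataloader=True` backs every prepared dataloader with
    # `torchdata.stateful_dataloader.StatefulDataLoader`, so its iteration state is
    # captured by `save_state`/`load_state` (this requires the `torchdata` package).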
    if args.with_tracking:
        accelerator = Accelerator(
            cpu=args.cpu,
            mixed_precision=args.mixed_precision,
            log_with="all",
            project_dir=args.project_dir,
            dataloader_config=dataloader_config,
        )
    else:
        accelerator = Accelerator(
            cpu=args.cpu, mixed_precision=args.mixed_precision, dataloader_config=dataloader_config
        )

    # Sample hyper-parameters for learning rate, batch size, seed and a few other HPs
    lr = config["lr"]
    num_epochs = int(config["num_epochs"])
    seed = int(config["seed"])
    batch_size = int(config["batch_size"])
    image_size = config["image_size"]
    if not isinstance(image_size, (list, tuple)):
        image_size = (image_size, image_size)

    # Parse out whether we are saving every epoch or after a certain number of batches
    if args.checkpointing_steps is not None:
        if args.checkpointing_steps == "epoch":
            checkpointing_steps = args.checkpointing_steps
        elif args.checkpointing_steps.isdigit():
            checkpointing_steps = int(args.checkpointing_steps)
        else:
            raise ValueError(
                f"Argument `checkpointing_steps` must be either a number or `epoch`. `{args.checkpointing_steps}` passed."
            )
    else:
        checkpointing_steps = None

    # We need to initialize the trackers we use, and also store our configuration
    if args.with_tracking:
        run = os.path.split(__file__)[-1].split(".")[0]
        accelerator.init_trackers(run, config)

    # Grab all the image filenames
    file_names = [os.path.join(args.data_dir, fname) for fname in os.listdir(args.data_dir) if fname.endswith(".jpg")]

    # Build the label correspondences
    all_labels = [extract_label(fname) for fname in file_names]
    id_to_label = list(set(all_labels))
    id_to_label.sort()
    label_to_id = {lbl: i for i, lbl in enumerate(id_to_label)}

    # Set the seed before splitting the data.
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # Split our filenames between train and validation
    random_perm = np.random.permutation(len(file_names))
    cut = int(0.8 * len(file_names))
    train_split = random_perm[:cut]
    eval_split = random_perm[cut:]

    # For training we use a simple RandomResizedCrop
    train_tfm = Compose([RandomResizedCrop(image_size, scale=(0.5, 1.0)), ToTensor()])
    train_dataset = PetsDataset(
        [file_names[i] for i in train_split], image_transform=train_tfm, label_to_id=label_to_id
    )

    # For evaluation, we use a deterministic Resize
    eval_tfm = Compose([Resize(image_size), ToTensor()])
    eval_dataset = PetsDataset([file_names[i] for i in eval_split], image_transform=eval_tfm, label_to_id=label_to_id)

    # Instantiate dataloaders.
    train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size, num_workers=4)
    eval_dataloader = DataLoader(eval_dataset, shuffle=False, batch_size=batch_size, num_workers=4)

    # Instantiate the model (we build the model here so that the seed also controls new weight initialization)
    model = create_model("resnet50d", pretrained=True, num_classes=len(label_to_id))

    # We could avoid this line since the accelerator is set with `device_placement=True` (default value).
    # Note that if you are placing tensors on devices manually, this line absolutely needs to be before the optimizer
    # creation otherwise training will not work on TPU (`accelerate` will kindly throw an error to make us aware of that).
    model = model.to(accelerator.device)

    # Freezing the base model
    for param in model.parameters():
        param.requires_grad = False
    for param in model.get_classifier().parameters():
        param.requires_grad = True

    # We normalize the batches of images to be a bit faster.
    mean = torch.tensor(model.default_cfg["mean"])[None, :, None, None].to(accelerator.device)
    std = torch.tensor(model.default_cfg["std"])[None, :, None, None].to(accelerator.device)
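    # For timm's "resnet50d" these are typically the standard ImageNet statistics
    # (mean ~ (0.485, 0.456, 0.406), std ~ (0.229, 0.224, 0.225)); the [None, :, None, None]
    # indexing reshapes them to (1, 3, 1, 1) so they broadcast over NCHW image batches.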

    # Instantiate optimizer
    optimizer = torch.optim.Adam(params=model.parameters(), lr=lr / 25)

    # Instantiate learning rate scheduler
    lr_scheduler = OneCycleLR(optimizer=optimizer, max_lr=lr, epochs=num_epochs, steps_per_epoch=len(train_dataloader))
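    # The `lr / 25` starting point mirrors OneCycleLR's default `div_factor=25`
    # (the schedule starts at max_lr / div_factor), so the optimizer's configured LR
    # agrees with the schedule's initial value.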

    # Prepare everything
    # There is no specific order to remember, we just need to unpack the objects in the same order we gave them to the
    # prepare method.
    model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare(
        model, optimizer, train_dataloader, eval_dataloader, lr_scheduler
    )
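    # `prepare` wraps each object for the current setup (distributed model, sharded
    # dataloaders, ...) and registers it with the Accelerator; that registration is
    # what lets `save_state`/`load_state` below checkpoint the model, optimizer,
    # scheduler and, with `use_stateful_dataloader`, the dataloader state together.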
    # We need to keep track of how many total steps we have iterated over
    overall_step = 0
    # We also need to keep track of the starting epoch so files are named properly
    starting_epoch = 0

    # Potentially load in the weights and states from a previous save
    if args.resume_from_checkpoint:
        if args.resume_from_checkpoint is not None and args.resume_from_checkpoint != "":
            accelerator.print(f"Resumed from checkpoint: {args.resume_from_checkpoint}")
            accelerator.load_state(args.resume_from_checkpoint)
            path = os.path.basename(args.resume_from_checkpoint)
        else:
            # Get the most recent checkpoint
            dirs = [f.name for f in os.scandir(os.getcwd()) if f.is_dir()]
            dirs.sort(key=os.path.getctime)
            path = dirs[-1]  # Sorts folders by date modified, most recent checkpoint is the last
        # Extract `epoch_{i}` or `step_{i}`
        training_difference = os.path.splitext(path)[0]

        if "epoch" in training_difference:
            starting_epoch = int(training_difference.replace("epoch_", "")) + 1
            resume_step = None
        else:
            resume_step = int(training_difference.replace("step_", ""))
            starting_epoch = resume_step // len(train_dataloader)
            resume_step -= starting_epoch * len(train_dataloader)

    # Now we train the model
    for epoch in range(starting_epoch, num_epochs):
        model.train()
        if args.with_tracking:
            total_loss = 0
        if args.resume_from_checkpoint and epoch == starting_epoch and resume_step is not None:
            # We need to skip steps until we reach the resumed step
            active_dataloader = accelerator.skip_first_batches(train_dataloader, resume_step)
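            # (`skip_first_batches` fast-forwards the dataloader past the batches that
            # were already consumed before the checkpoint, so the interrupted epoch is
            # not replayed from its beginning.)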
            overall_step += resume_step
        else:
            # After the first iteration though, we need to go back to the original dataloader
            active_dataloader = train_dataloader
        for batch in active_dataloader:
            # We could avoid this line since we set the accelerator with `device_placement=True`.
            batch = {k: v.to(accelerator.device) for k, v in batch.items()}
            inputs = (batch["image"] - mean) / std
            outputs = model(inputs)
            loss = torch.nn.functional.cross_entropy(outputs, batch["label"])
            # We keep track of the loss at each epoch
            if args.with_tracking:
                total_loss += loss.detach().float()
            accelerator.backward(loss)
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()
            overall_step += 1
            if isinstance(checkpointing_steps, int):
                output_dir = f"step_{overall_step}"
                if overall_step % checkpointing_steps == 0:
                    if args.output_dir is not None:
                        output_dir = os.path.join(args.output_dir, output_dir)
                    accelerator.save_state(output_dir)
        model.eval()
        accurate = 0
        num_elems = 0
        for step, batch in enumerate(eval_dataloader):
            # We could avoid this line since we set the accelerator with `device_placement=True`.
            batch = {k: v.to(accelerator.device) for k, v in batch.items()}
            inputs = (batch["image"] - mean) / std
            with torch.no_grad():
                outputs = model(inputs)
            predictions = outputs.argmax(dim=-1)
            predictions, references = accelerator.gather_for_metrics((predictions, batch["label"]))
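            # `gather_for_metrics` collects predictions/labels from every process and
            # drops the duplicate samples that were added to round out the last batch,
            # so the accuracy below is computed over exactly the evaluation set.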
            accurate_preds = predictions == references
            num_elems += accurate_preds.shape[0]
            accurate += accurate_preds.long().sum()

        eval_metric = accurate.item() / num_elems
        # Use accelerator.print to print only on the main process.
        accelerator.print(f"epoch {epoch}: {100 * eval_metric:.2f}")
        if args.with_tracking:
            accelerator.log(
                {
                    "accuracy": 100 * eval_metric,
                    "train_loss": total_loss.item() / len(train_dataloader),
                    "epoch": epoch,
                },
                step=overall_step,
            )
        if checkpointing_steps == "epoch":
            output_dir = f"epoch_{epoch}"
            if args.output_dir is not None:
                output_dir = os.path.join(args.output_dir, output_dir)
            accelerator.save_state(output_dir)

    accelerator.end_training()


def main():
    parser = argparse.ArgumentParser(description="Simple example of training script.")
    parser.add_argument("--data_dir", required=True, help="The data folder on disk.")
    parser.add_argument("--fp16", action="store_true", help="If passed, will use FP16 training.")
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default=None,
        choices=["no", "fp16", "bf16", "fp8"],
        help="Whether to use mixed precision. Choose "
        "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10 "
        "and an Nvidia Ampere GPU.",
    )
    parser.add_argument("--cpu", action="store_true", help="If passed, will train on the CPU.")
    parser.add_argument(
        "--checkpointing_steps",
        type=str,
        default=None,
        help="Whether the various states should be saved at the end of every n steps, or 'epoch' for each epoch.",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default=".",
        help="Optional save directory where all checkpoint folders will be stored. Default is the current working directory.",
    )
    parser.add_argument(
        "--resume_from_checkpoint",
        type=str,
        default=None,
        help="If the training should continue from a checkpoint folder.",
    )
    parser.add_argument(
        "--use_stateful_dataloader",
        action="store_true",
        help="If the dataloader should be a resumable stateful dataloader.",
    )
    parser.add_argument(
        "--with_tracking",
        action="store_true",
        help="Whether to load in all available experiment trackers from the environment and use them for logging.",
    )
    parser.add_argument(
        "--project_dir",
        type=str,
        default="logs",
        help="Location where to store experiment tracking logs and relevant project information",
    )
    args = parser.parse_args()
    config = {"lr": 3e-2, "num_epochs": 3, "seed": 42, "batch_size": 64, "image_size": 224}
    training_function(config, args)


if __name__ == "__main__":
    main()