Files
accelerate/examples/nlp_example.py
Philipp Schmid e93cb7a3bd Launch script on sagemaker (#26)
* fixed loading SageMaker Environment

* added utils and dynamic args parser for sagemaker

* added args converter for SageMaker with type inference

* added launch test

* added sagemaker launcher

* added test

* better print statements

* accelerate as requirements.txt for sagemaker

* make style

* adjusted NLP example and removed store actions, since SageMaker cannot handle them

* added documentation side

* added pyaml as dependency

* added doc changes

* reworked doc to .rst to highlight warning and notes better

* quality

* added error raise for store actions and added test

* quality

* Update docs/source/sagemaker.rst

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update docs/source/sagemaker.rst

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* moved fp16 from parameter to environment

* Update docs/source/sagemaker.rst

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>

* Update docs/source/sagemaker.rst

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>

* Update docs/source/sagemaker.rst

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
2021-04-14 16:51:27 -04:00

174 lines
6.8 KiB
Python

# coding=utf-8
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import torch
from torch.utils.data import DataLoader
from accelerate import Accelerator, DistributedType
from datasets import load_dataset, load_metric
from transformers import (
AdamW,
AutoModelForSequenceClassification,
AutoTokenizer,
get_linear_schedule_with_warmup,
set_seed,
)
########################################################################
# This is a fully working simple example to use Accelerate
#
# This example trains a BERT base model on GLUE MRPC
# in any of the following settings (with the same script):
#   - single CPU or single GPU
#   - multi-GPU (using PyTorch distributed mode)
#   - (multi) TPUs
#   - fp16 (mixed-precision) or fp32 (normal precision)
#
# To run it in each of these various modes, follow the instructions
# in the examples README:
# https://github.com/huggingface/accelerate/tree/main/examples
#
########################################################################
# Largest training batch pushed through the model in a single forward/backward pass; bigger requested
# batch sizes are reached through gradient accumulation in `training_function`.
MAX_GPU_BATCH_SIZE: int = 16
# Batch size of the validation dataloader.
EVAL_BATCH_SIZE: int = 32
def training_function(config, args):
    """Fine-tune ``bert-base-cased`` on GLUE MRPC and print the validation metric after every epoch.

    Args:
        config: Hyper-parameter mapping with the keys ``"lr"``, ``"num_epochs"``, ``"correct_bias"``, ``"seed"``
            and ``"batch_size"``. ``batch_size`` is the effective batch size per optimizer update; values above
            ``MAX_GPU_BATCH_SIZE`` are reached with gradient accumulation.
        args: Parsed command-line namespace; its ``fp16`` and ``cpu`` attributes are forwarded to ``Accelerator``.
    """
    # Initialize accelerator
    accelerator = Accelerator(fp16=args.fp16, cpu=args.cpu)

    # Sample hyper-parameters for learning rate, batch size, seed and a few other HPs
    lr = config["lr"]
    num_epochs = int(config["num_epochs"])
    correct_bias = config["correct_bias"]
    seed = int(config["seed"])
    batch_size = int(config["batch_size"])

    tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
    datasets = load_dataset("glue", "mrpc")
    metric = load_metric("glue", "mrpc")

    def tokenize_function(examples):
        # max_length=None => use the model max length (it's actually the default)
        outputs = tokenizer(examples["sentence1"], examples["sentence2"], truncation=True, max_length=None)
        return outputs

    # Apply the method we just defined to all the examples in all the splits of the dataset
    tokenized_datasets = datasets.map(
        tokenize_function,
        batched=True,
        remove_columns=["idx", "sentence1", "sentence2"],
    )

    # We also rename the 'label' column to 'labels' which is the expected name for labels by the models of the
    # transformers library. `rename_column` returns the renamed datasets, so the result is re-bound here
    # (the in-place `rename_column_` variant is deprecated in `datasets`).
    tokenized_datasets = tokenized_datasets.rename_column("label", "labels")

    # If the batch size is too big we use gradient accumulation
    gradient_accumulation_steps = 1
    if batch_size > MAX_GPU_BATCH_SIZE:
        gradient_accumulation_steps = batch_size // MAX_GPU_BATCH_SIZE
        batch_size = MAX_GPU_BATCH_SIZE

    def collate_fn(examples):
        # On TPU it's best to pad everything to the same length or training will be very slow.
        if accelerator.distributed_type == DistributedType.TPU:
            return tokenizer.pad(examples, padding="max_length", max_length=128, return_tensors="pt")
        return tokenizer.pad(examples, padding="longest", return_tensors="pt")

    # Instantiate dataloaders.
    train_dataloader = DataLoader(
        tokenized_datasets["train"], shuffle=True, collate_fn=collate_fn, batch_size=batch_size
    )
    eval_dataloader = DataLoader(
        tokenized_datasets["validation"], shuffle=False, collate_fn=collate_fn, batch_size=EVAL_BATCH_SIZE
    )

    set_seed(seed)

    # Instantiate the model (we build the model here so that the seed also control new weights initialization)
    model = AutoModelForSequenceClassification.from_pretrained("bert-base-cased", return_dict=True)

    # We could avoid this line since the accelerator is set with `device_placement=True` (default value).
    # Note that if you are placing tensors on devices manually, this line absolutely needs to be before the optimizer
    # creation otherwise training will not work on TPU (`accelerate` will kindly throw an error to make us aware of that).
    model = model.to(accelerator.device)

    # Instantiate optimizer
    optimizer = AdamW(params=model.parameters(), lr=lr, correct_bias=correct_bias)

    # Prepare everything
    # There is no specific order to remember, we just need to unpack the objects in the same order we gave them to the
    # prepare method.
    model, optimizer, train_dataloader, eval_dataloader = accelerator.prepare(
        model, optimizer, train_dataloader, eval_dataloader
    )

    # Measure the training dataloader only after `prepare`, since the prepare method may change its length.
    num_train_batches = len(train_dataloader)
    # The optimizer and the scheduler step once per group of `gradient_accumulation_steps` batches (the last,
    # possibly partial, group of an epoch included), so the schedule length is counted in optimizer updates
    # rather than in batches. Otherwise the linear decay would never finish when accumulating gradients.
    num_update_steps_per_epoch = (num_train_batches + gradient_accumulation_steps - 1) // gradient_accumulation_steps
    lr_scheduler = get_linear_schedule_with_warmup(
        optimizer=optimizer,
        num_warmup_steps=100,
        num_training_steps=num_update_steps_per_epoch * num_epochs,
    )

    # Now we train the model
    for epoch in range(num_epochs):
        model.train()
        for step, batch in enumerate(train_dataloader):
            # We could avoid this line since we set the accelerator with `device_placement=True`.
            batch.to(accelerator.device)
            outputs = model(**batch)
            loss = outputs.loss
            loss = loss / gradient_accumulation_steps
            accelerator.backward(loss)
            # Update once a full group of batches has been accumulated, and always on the last batch of the
            # epoch so leftover gradients are applied here instead of leaking into the next epoch.
            is_last_batch = step + 1 == num_train_batches
            if (step + 1) % gradient_accumulation_steps == 0 or is_last_batch:
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

        model.eval()
        for step, batch in enumerate(eval_dataloader):
            # We could avoid this line since we set the accelerator with `device_placement=True`.
            batch.to(accelerator.device)
            with torch.no_grad():
                outputs = model(**batch)
            predictions = outputs.logits.argmax(dim=-1)
            # Gather predictions and labels from every process before feeding the metric.
            metric.add_batch(
                predictions=accelerator.gather(predictions),
                references=accelerator.gather(batch["labels"]),
            )

        eval_metric = metric.compute()
        # Use accelerator.print to print only on the main process.
        accelerator.print(f"epoch {epoch}:", eval_metric)
def _str_to_bool(value):
    """Convert a command-line string such as ``"True"`` or ``"false"`` into a ``bool``.

    ``argparse`` with ``type=bool`` calls ``bool(value)``, which is ``True`` for every non-empty string
    (including ``"False"``), so the text is interpreted explicitly instead.

    Args:
        value: The raw argument string (a ``bool`` is returned unchanged).

    Returns:
        ``True`` for ``true/t/yes/y/1`` and ``False`` for ``false/f/no/n/0`` or an empty string
        (case-insensitive, surrounding whitespace ignored).

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognised boolean spelling, so ``argparse`` reports a
            usage error instead of silently enabling the option.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in ("true", "t", "yes", "y", "1"):
        return True
    # An empty string was already False under the previous `type=bool`, so keep accepting it.
    if normalized in ("false", "f", "no", "n", "0", ""):
        return False
    raise argparse.ArgumentTypeError(f"Expected a boolean value (e.g. True/False), got {value!r}.")


def main():
    """Parse the command-line flags and run ``training_function`` with this example's hyper-parameters."""
    parser = argparse.ArgumentParser(description="Simple example of training script.")
    # Boolean options take an explicit value (e.g. `--fp16 True`) instead of a `store_true` action because the
    # SageMaker launcher cannot handle store actions; a bare `--fp16` / `--cpu` is still accepted and means True.
    parser.add_argument(
        "--fp16",
        type=_str_to_bool,
        nargs="?",
        const=True,
        default=False,
        help="If passed, will use FP16 training.",
    )
    parser.add_argument(
        "--cpu",
        type=_str_to_bool,
        nargs="?",
        const=True,
        default=False,
        help="If passed, will train on the CPU.",
    )
    args = parser.parse_args()
    config = {"lr": 2e-5, "num_epochs": 3, "correct_bias": True, "seed": 42, "batch_size": 16}
    training_function(config, args)


if __name__ == "__main__":
    main()