Files
pytorch/functorch/examples/maml_omniglot/maml-omniglot-transforms.py
Aaron Gokaslan 5471621497 [BE] Remove unnecessary dict comprehensions (#97116)
Removes unnecessary dict comprehensions that optimize creation of dicts from iterables

Pull Request resolved: https://github.com/pytorch/pytorch/pull/97116
Approved by: https://github.com/kit1980
2023-03-20 00:56:57 +00:00

257 lines
8.6 KiB
Python
Executable File

#!/usr/bin/env python3
#
# Copyright (c) Facebook, Inc. and its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This example shows how to use higher to do Model Agnostic Meta Learning (MAML)
for few-shot Omniglot classification.
For more details see the original MAML paper:
https://arxiv.org/abs/1703.03400
This code has been modified from Jackie Loong's PyTorch MAML implementation:
https://github.com/dragen1860/MAML-Pytorch/blob/master/omniglot_train.py
Our MAML++ fork and experiments are available at:
https://github.com/bamos/HowToTrainYourMAMLPytorch
"""
from support.omniglot_loaders import OmniglotNShot
from torch.func import vmap, grad, functional_call
import torch.optim as optim
import torch.nn.functional as F
from torch import nn
import torch
import matplotlib.pyplot as plt
import argparse
import time
import functools
import pandas as pd
import numpy as np
import matplotlib as mpl
mpl.use('Agg')
plt.style.use('bmh')
def main():
argparser = argparse.ArgumentParser()
argparser.add_argument('--n-way', '--n_way', type=int, help='n way', default=5)
argparser.add_argument(
'--k-spt', '--k_spt', type=int, help='k shot for support set', default=5)
argparser.add_argument(
'--k-qry', '--k_qry', type=int, help='k shot for query set', default=15)
argparser.add_argument(
'--device', type=str, help='device', default='cuda')
argparser.add_argument(
'--task-num', '--task_num',
type=int,
help='meta batch size, namely task num',
default=32)
argparser.add_argument('--seed', type=int, help='random seed', default=1)
args = argparser.parse_args()
torch.manual_seed(args.seed)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(args.seed)
np.random.seed(args.seed)
# Set up the Omniglot loader.
device = args.device
db = OmniglotNShot(
'/tmp/omniglot-data',
batchsz=args.task_num,
n_way=args.n_way,
k_shot=args.k_spt,
k_query=args.k_qry,
imgsz=28,
device=device,
)
# Create a vanilla PyTorch neural network.
inplace_relu = True
net = nn.Sequential(
nn.Conv2d(1, 64, 3),
nn.BatchNorm2d(64, affine=True, track_running_stats=False),
nn.ReLU(inplace=inplace_relu),
nn.MaxPool2d(2, 2),
nn.Conv2d(64, 64, 3),
nn.BatchNorm2d(64, affine=True, track_running_stats=False),
nn.ReLU(inplace=inplace_relu),
nn.MaxPool2d(2, 2),
nn.Conv2d(64, 64, 3),
nn.BatchNorm2d(64, affine=True, track_running_stats=False),
nn.ReLU(inplace=inplace_relu),
nn.MaxPool2d(2, 2),
nn.Flatten(),
nn.Linear(64, args.n_way)).to(device)
net.train()
# We will use Adam to (meta-)optimize the initial parameters
# to be adapted.
meta_opt = optim.Adam(net.parameters(), lr=1e-3)
log = []
for epoch in range(100):
train(db, net, device, meta_opt, epoch, log)
test(db, net, device, epoch, log)
plot(log)
# Trains a model for n_inner_iter using the support and returns a loss
# using the query.
def loss_for_task(net, n_inner_iter, x_spt, y_spt, x_qry, y_qry):
params = dict(net.named_parameters())
buffers = dict(net.named_buffers())
querysz = x_qry.size(0)
def compute_loss(new_params, buffers, x, y):
logits = functional_call(net, (new_params, buffers), x)
loss = F.cross_entropy(logits, y)
return loss
new_params = params
for _ in range(n_inner_iter):
grads = grad(compute_loss)(new_params, buffers, x_spt, y_spt)
new_params = {k: new_params[k] - g * 1e-1 for k, g, in grads.items()}
# The final set of adapted parameters will induce some
# final loss and accuracy on the query dataset.
# These will be used to update the model's meta-parameters.
qry_logits = functional_call(net, (new_params, buffers), x_qry)
qry_loss = F.cross_entropy(qry_logits, y_qry)
qry_acc = (qry_logits.argmax(
dim=1) == y_qry).sum() / querysz
return qry_loss, qry_acc
def train(db, net, device, meta_opt, epoch, log):
params = dict(net.named_parameters())
buffers = dict(net.named_buffers())
n_train_iter = db.x_train.shape[0] // db.batchsz
for batch_idx in range(n_train_iter):
start_time = time.time()
# Sample a batch of support and query images and labels.
x_spt, y_spt, x_qry, y_qry = db.next()
task_num, setsz, c_, h, w = x_spt.size()
n_inner_iter = 5
meta_opt.zero_grad()
# In parallel, trains one model per task. There is a support (x, y)
# for each task and a query (x, y) for each task.
compute_loss_for_task = functools.partial(loss_for_task, net, n_inner_iter)
qry_losses, qry_accs = vmap(compute_loss_for_task)(x_spt, y_spt, x_qry, y_qry)
# Compute the maml loss by summing together the returned losses.
qry_losses.sum().backward()
meta_opt.step()
qry_losses = qry_losses.detach().sum() / task_num
qry_accs = 100. * qry_accs.sum() / task_num
i = epoch + float(batch_idx) / n_train_iter
iter_time = time.time() - start_time
if batch_idx % 4 == 0:
print(
f'[Epoch {i:.2f}] Train Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f} | Time: {iter_time:.2f}'
)
log.append({
'epoch': i,
'loss': qry_losses,
'acc': qry_accs,
'mode': 'train',
'time': time.time(),
})
def test(db, net, device, epoch, log):
# Crucially in our testing procedure here, we do *not* fine-tune
# the model during testing for simplicity.
# Most research papers using MAML for this task do an extra
# stage of fine-tuning here that should be added if you are
# adapting this code for research.
params = dict(net.named_parameters())
buffers = dict(net.named_buffers())
n_test_iter = db.x_test.shape[0] // db.batchsz
qry_losses = []
qry_accs = []
for batch_idx in range(n_test_iter):
x_spt, y_spt, x_qry, y_qry = db.next('test')
task_num, setsz, c_, h, w = x_spt.size()
# TODO: Maybe pull this out into a separate module so it
# doesn't have to be duplicated between `train` and `test`?
n_inner_iter = 5
for i in range(task_num):
new_params = params
for _ in range(n_inner_iter):
spt_logits = functional_call(net, (new_params, buffers), x_spt[i])
spt_loss = F.cross_entropy(spt_logits, y_spt[i])
grads = torch.autograd.grad(spt_loss, new_params.values())
new_params = {k: new_params[k] - g * 1e-1 for k, g, in zip(new_params, grads)}
# The query loss and acc induced by these parameters.
qry_logits = functional_call(net, (new_params, buffers), x_qry[i]).detach()
qry_loss = F.cross_entropy(
qry_logits, y_qry[i], reduction='none')
qry_losses.append(qry_loss.detach())
qry_accs.append(
(qry_logits.argmax(dim=1) == y_qry[i]).detach())
qry_losses = torch.cat(qry_losses).mean().item()
qry_accs = 100. * torch.cat(qry_accs).float().mean().item()
print(
f'[Epoch {epoch+1:.2f}] Test Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f}'
)
log.append({
'epoch': epoch + 1,
'loss': qry_losses,
'acc': qry_accs,
'mode': 'test',
'time': time.time(),
})
def plot(log):
# Generally you should pull your plotting code out of your training
# script but we are doing it here for brevity.
df = pd.DataFrame(log)
fig, ax = plt.subplots(figsize=(6, 4))
train_df = df[df['mode'] == 'train']
test_df = df[df['mode'] == 'test']
ax.plot(train_df['epoch'], train_df['acc'], label='Train')
ax.plot(test_df['epoch'], test_df['acc'], label='Test')
ax.set_xlabel('Epoch')
ax.set_ylabel('Accuracy')
ax.set_ylim(70, 100)
fig.legend(ncol=2, loc='lower right')
fig.tight_layout()
fname = 'maml-accs.png'
print(f'--- Plotting accuracy to {fname}')
fig.savefig(fname)
plt.close(fig)
if __name__ == '__main__':
main()