#!/usr/bin/env python3
#
# Copyright (c) Facebook, Inc. and its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This example shows how to use higher to do Model Agnostic Meta Learning (MAML)
for few-shot Omniglot classification.

For more details see the original MAML paper:
https://arxiv.org/abs/1703.03400

This code has been modified from Jackie Loong's PyTorch MAML implementation:
https://github.com/dragen1860/MAML-Pytorch/blob/master/omniglot_train.py

Our MAML++ fork and experiments are available at:
https://github.com/bamos/HowToTrainYourMAMLPytorch
"""

import argparse
import time

import higher
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from support.omniglot_loaders import OmniglotNShot

import torch
import torch.nn.functional as F
import torch.optim as optim
from torch import nn

mpl.use("Agg")
plt.style.use("bmh")


def main():
    argparser = argparse.ArgumentParser()
    argparser.add_argument("--n-way", "--n_way", type=int, help="n way", default=5)
    argparser.add_argument(
        "--k-spt", "--k_spt", type=int, help="k shot for support set", default=5
    )
    argparser.add_argument(
        "--k-qry", "--k_qry", type=int, help="k shot for query set", default=15
    )
    argparser.add_argument("--device", type=str, help="device", default="cuda")
    argparser.add_argument(
        "--task-num",
        "--task_num",
        type=int,
        help="meta batch size, namely task num",
        default=32,
    )
    argparser.add_argument("--seed", type=int, help="random seed", default=1)
    args = argparser.parse_args()

    torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)
    np.random.seed(args.seed)

    # Set up the Omniglot loader.
    device = args.device
    db = OmniglotNShot(
        "/tmp/omniglot-data",
        batchsz=args.task_num,
        n_way=args.n_way,
        k_shot=args.k_spt,
        k_query=args.k_qry,
        imgsz=28,
        device=device,
    )

    # Create a vanilla PyTorch neural network that will be
    # automatically monkey-patched by higher later.
    # Before higher, models could *not* be created like this
    # and the parameters needed to be manually updated and copied
    # for the updates.
    net = nn.Sequential(
        nn.Conv2d(1, 64, 3),
        nn.BatchNorm2d(64, momentum=1, affine=True),
        nn.ReLU(inplace=True),
        nn.MaxPool2d(2, 2),
        nn.Conv2d(64, 64, 3),
        nn.BatchNorm2d(64, momentum=1, affine=True),
        nn.ReLU(inplace=True),
        nn.MaxPool2d(2, 2),
        nn.Conv2d(64, 64, 3),
        nn.BatchNorm2d(64, momentum=1, affine=True),
        nn.ReLU(inplace=True),
        nn.MaxPool2d(2, 2),
        Flatten(),
        nn.Linear(64, args.n_way),
    ).to(device)

    # We will use Adam to (meta-)optimize the initial parameters
    # to be adapted.
    meta_opt = optim.Adam(net.parameters(), lr=1e-3)

    log = []
    for epoch in range(100):
        train(db, net, device, meta_opt, epoch, log)
        test(db, net, device, epoch, log)
        plot(log)


def train(db, net, device, meta_opt, epoch, log):
    net.train()
    n_train_iter = db.x_train.shape[0] // db.batchsz

    for batch_idx in range(n_train_iter):
        start_time = time.time()
        # Sample a batch of support and query images and labels.
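        # Shape sketch (inferred from how the tensors are unpacked just below):
        # x_spt is [task_num, setsz, c_, h, w] and x_qry is
        # [task_num, querysz, c_, h, w], i.e. one support/query set per task.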
        x_spt, y_spt, x_qry, y_qry = db.next()

        task_num, setsz, c_, h, w = x_spt.size()
        querysz = x_qry.size(1)

        # TODO: Maybe pull this out into a separate module so it
        # doesn't have to be duplicated between `train` and `test`?

        # Initialize the inner optimizer to adapt the parameters to
        # the support set.
        n_inner_iter = 5
        inner_opt = torch.optim.SGD(net.parameters(), lr=1e-1)

        qry_losses = []
        qry_accs = []
        meta_opt.zero_grad()
        for i in range(task_num):
            with higher.innerloop_ctx(net, inner_opt, copy_initial_weights=False) as (
                fnet,
                diffopt,
            ):
                # Optimize the likelihood of the support set by taking
                # gradient steps w.r.t. the model's parameters.
                # This adapts the model's meta-parameters to the task.
                # higher is able to automatically keep copies of
                # your network's parameters as they are being updated.
                for _ in range(n_inner_iter):
                    spt_logits = fnet(x_spt[i])
                    spt_loss = F.cross_entropy(spt_logits, y_spt[i])
                    diffopt.step(spt_loss)

                # The final set of adapted parameters will induce some
                # final loss and accuracy on the query dataset.
                # These will be used to update the model's meta-parameters.
                qry_logits = fnet(x_qry[i])
                qry_loss = F.cross_entropy(qry_logits, y_qry[i])
                qry_losses.append(qry_loss.detach())
                qry_acc = (qry_logits.argmax(dim=1) == y_qry[i]).sum().item() / querysz
                qry_accs.append(qry_acc)

                # print([b.shape for b in fnet[1].buffers()])

                # Update the model's meta-parameters to optimize the query
                # losses across all of the tasks sampled in this batch.
                # This unrolls through the gradient steps.
                qry_loss.backward()

        meta_opt.step()
        qry_losses = sum(qry_losses) / task_num
        qry_accs = 100.0 * sum(qry_accs) / task_num
        i = epoch + float(batch_idx) / n_train_iter
        iter_time = time.time() - start_time
        if batch_idx % 4 == 0:
            print(
                f"[Epoch {i:.2f}] Train Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f} | Time: {iter_time:.2f}"
            )

        log.append(
            {
                "epoch": i,
                "loss": qry_losses,
                "acc": qry_accs,
                "mode": "train",
                "time": time.time(),
            }
        )


def test(db, net, device, epoch, log):
    # Crucially in our testing procedure here, we do *not* fine-tune
    # the model during testing for simplicity.
    # Most research papers using MAML for this task do an extra
    # stage of fine-tuning here that should be added if you are
    # adapting this code for research.
    net.train()
    n_test_iter = db.x_test.shape[0] // db.batchsz

    qry_losses = []
    qry_accs = []

    for _ in range(n_test_iter):
        x_spt, y_spt, x_qry, y_qry = db.next("test")

        task_num, setsz, c_, h, w = x_spt.size()

        # TODO: Maybe pull this out into a separate module so it
        # doesn't have to be duplicated between `train` and `test`?
        n_inner_iter = 5
        inner_opt = torch.optim.SGD(net.parameters(), lr=1e-1)

        for i in range(task_num):
            with higher.innerloop_ctx(net, inner_opt, track_higher_grads=False) as (
                fnet,
                diffopt,
            ):
                # Optimize the likelihood of the support set by taking
                # gradient steps w.r.t. the model's parameters.
                # This adapts the model's meta-parameters to the task.
                for _ in range(n_inner_iter):
                    spt_logits = fnet(x_spt[i])
                    spt_loss = F.cross_entropy(spt_logits, y_spt[i])
                    diffopt.step(spt_loss)

                # The query loss and acc induced by these parameters.
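                # Detaching below is fine at test time: track_higher_grads=False
                # above means no differentiable inner-loop graph is kept, and we
                # only need the predictions for evaluation, not meta-gradients.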
                qry_logits = fnet(x_qry[i]).detach()
                qry_loss = F.cross_entropy(qry_logits, y_qry[i], reduction="none")
                qry_losses.append(qry_loss.detach())
                qry_accs.append((qry_logits.argmax(dim=1) == y_qry[i]).detach())

    qry_losses = torch.cat(qry_losses).mean().item()
    qry_accs = 100.0 * torch.cat(qry_accs).float().mean().item()
    print(f"[Epoch {epoch+1:.2f}] Test Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f}")
    log.append(
        {
            "epoch": epoch + 1,
            "loss": qry_losses,
            "acc": qry_accs,
            "mode": "test",
            "time": time.time(),
        }
    )


def plot(log):
    # Generally you should pull your plotting code out of your training
    # script but we are doing it here for brevity.
    df = pd.DataFrame(log)

    fig, ax = plt.subplots(figsize=(6, 4))
    train_df = df[df["mode"] == "train"]
    test_df = df[df["mode"] == "test"]
    ax.plot(train_df["epoch"], train_df["acc"], label="Train")
    ax.plot(test_df["epoch"], test_df["acc"], label="Test")
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Accuracy")
    ax.set_ylim(70, 100)
    fig.legend(ncol=2, loc="lower right")
    fig.tight_layout()
    fname = "maml-accs.png"
    print(f"--- Plotting accuracy to {fname}")
    fig.savefig(fname)
    plt.close(fig)


# Won't need this after this PR is merged in:
# https://github.com/pytorch/pytorch/pull/22245
class Flatten(nn.Module):
    def forward(self, input):
        return input.view(input.size(0), -1)


if __name__ == "__main__":
    main()
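

# Example invocation (a sketch; assumes this file is saved as maml-omniglot.py,
# that the `higher` package and the `support/` loaders are importable, and that
# a CUDA device is available; otherwise pass `--device cpu`):
#
#     python maml-omniglot.py --n-way 5 --k-spt 5 --k-qry 15 --task-num 32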