[functorch] a lot of files

This commit is contained in:
Richard Zou
2021-04-20 14:28:02 -07:00
committed by Jon Janzen
parent 608a932c1a
commit 93888a3779
47 changed files with 12610 additions and 0 deletions

functorch/.gitignore vendored Normal file
View File

@@ -0,0 +1,4 @@
build/
dist/
functorch.egg-info/
*__pycache__*

functorch/examples/.gitignore vendored Normal file
View File

@@ -0,0 +1 @@
cifar10/

View File

@@ -0,0 +1,2 @@
catch throw
r cifar10_transforms.py

View File

@@ -0,0 +1,436 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
Runs CIFAR10 training with differential privacy.
"""
import argparse
import os
import shutil
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data
import torch.utils.data.distributed
import torch.utils.tensorboard as tensorboard
import torchvision.models as models
import torchvision.transforms as transforms
from opacus import PrivacyEngine
from opacus.utils import stats
from opacus.utils.module_modification import convert_batchnorm_modules
from torchvision.datasets import CIFAR10
from tqdm import tqdm
# This is based on zou3519/pytorch:expand_weights.
def save_checkpoint(state, is_best, filename="checkpoint.tar"):
torch.save(state, filename)
if is_best:
shutil.copyfile(filename, "model_best.pth.tar")
def accuracy(preds, labels):
return (preds == labels).mean()
def compute_norms(sample_grads):
batch_size = sample_grads[0].shape[0]
norms = [sample_grad.view(batch_size, -1).norm(2, dim=-1) for sample_grad in sample_grads]
norms = torch.stack(norms, dim=0).norm(2, dim=0)
return norms
def clip_and_accumulate_and_add_noise(model, max_per_sample_grad_norm=1.0, noise_multiplier=1.0):
sample_grads = tuple(param.grad_sample for param in model.parameters())
# step 0: compute the norms
sample_norms = compute_norms(sample_grads)
# step 1: compute clipping factors
clip_factor = max_per_sample_grad_norm / (sample_norms + 1e-6)
clip_factor = clip_factor.clamp(max=1.0)
# step 2: clip
grads = tuple(torch.einsum('i,i...', clip_factor, sample_grad)
for sample_grad in sample_grads)
# step 3: add gaussian noise
stddev = max_per_sample_grad_norm * noise_multiplier
noises = tuple(torch.normal(0, stddev, grad_param.shape, device=grad_param.device)
for grad_param in grads)
grads = tuple(noise + grad_param for noise, grad_param in zip(noises, grads))
# step 4: assign the new grads, delete the sample grads
for param, param_grad in zip(model.parameters(), grads):
param.grad = param_grad
del param.grad_sample
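# In other words, this implements the usual DP-SGD recipe: each per-sample gradient g_i is
# scaled by min(1, C / ||g_i||_2) with C = max_per_sample_grad_norm, the einsum('i,i...')
# above forms the clip-weighted sum over the batch dimension, and Gaussian noise with
# stddev C * noise_multiplier is added to every summed gradient.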
def train(args, model, train_loader, optimizer, epoch, device):
model.train()
criterion = nn.CrossEntropyLoss()
losses = []
top1_acc = []
for i, (images, target) in enumerate(tqdm(train_loader)):
images = images.to(device)
target = target.to(device)
# Step 1: compute per-sample-grads (provided by pytorch)
# loss.backward() populates the grad_sample attribute of each param
with model.compute_per_sample_grads(batch_size=images.shape[0]):
output = model(images)
loss = criterion(output, target)
loss.backward()
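# After backward() under compute_per_sample_grads, each parameter is expected to carry a
# `grad_sample` attribute of shape [batch_size, *param.shape] holding its per-sample gradients.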
# Step 2: Clip the per-sample-grads, sum them to form grads, and add noise
# Opacus implements this but I wrote a custom one to show how this would work.
# This deletes the grad_sample attributes and populates the grad attributes
clip_and_accumulate_and_add_noise(
model, args.max_per_sample_grad_norm, args.sigma)
preds = np.argmax(output.detach().cpu().numpy(), axis=1)
labels = target.detach().cpu().numpy()
losses.append(loss.item())
# measure accuracy and record loss
acc1 = accuracy(preds, labels)
top1_acc.append(acc1)
stats.update(stats.StatType.TRAIN, acc1=acc1)
# make sure we take a step after processing the last mini-batch in the
# epoch to ensure we start the next epoch with a clean state
if ((i + 1) % args.n_accumulation_steps == 0) or ((i + 1) == len(train_loader)):
optimizer.step()
optimizer.zero_grad()
else:
optimizer.virtual_step()
if i % args.print_freq == 0:
print(
f"\tTrain Epoch: {epoch} \t"
f"Loss: {np.mean(losses):.6f} "
f"Acc@1: {np.mean(top1_acc):.6f} "
)
def test(args, model, test_loader, device):
model.eval()
criterion = nn.CrossEntropyLoss()
losses = []
top1_acc = []
with torch.no_grad():
for images, target in tqdm(test_loader):
images = images.to(device)
target = target.to(device)
output = model(images)
loss = criterion(output, target)
preds = np.argmax(output.detach().cpu().numpy(), axis=1)
labels = target.detach().cpu().numpy()
acc1 = accuracy(preds, labels)
losses.append(loss.item())
top1_acc.append(acc1)
top1_avg = np.mean(top1_acc)
stats.update(stats.StatType.TEST, acc1=top1_avg)
print(f"\tTest set:" f"Loss: {np.mean(losses):.6f} " f"Acc@1: {top1_avg :.6f} ")
return np.mean(top1_acc)
def main():
parser = argparse.ArgumentParser(description="PyTorch CIFAR10 DP Training")
parser.add_argument(
"-j",
"--workers",
default=2,
type=int,
metavar="N",
help="number of data loading workers (default: 2)",
)
parser.add_argument(
"--epochs",
default=90,
type=int,
metavar="N",
help="number of total epochs to run",
)
parser.add_argument(
"--start-epoch",
default=1,
type=int,
metavar="N",
help="manual epoch number (useful on restarts)",
)
parser.add_argument(
"-b",
"--batch-size",
# This should be 256, but that OOMs using the prototype.
default=64,
type=int,
metavar="N",
help="mini-batch size (default: 64), this is the total "
"batch size of all GPUs on the current node when "
"using Data Parallel or Distributed Data Parallel",
)
parser.add_argument(
"-na",
"--n_accumulation_steps",
default=1,
type=int,
metavar="N",
help="number of mini-batches to accumulate into an effective batch",
)
parser.add_argument(
"--lr",
"--learning-rate",
default=0.001,
type=float,
metavar="LR",
help="initial learning rate",
dest="lr",
)
parser.add_argument(
"--momentum", default=0.9, type=float, metavar="M", help="SGD momentum"
)
parser.add_argument(
"--wd",
"--weight-decay",
default=5e-4,
type=float,
metavar="W",
help="SGD weight decay (default: 1e-4)",
dest="weight_decay",
)
parser.add_argument(
"-p",
"--print-freq",
default=10,
type=int,
metavar="N",
help="print frequency (default: 10)",
)
parser.add_argument(
"--resume",
default="",
type=str,
metavar="PATH",
help="path to latest checkpoint (default: none)",
)
parser.add_argument(
"-e",
"--evaluate",
dest="evaluate",
action="store_true",
help="evaluate model on validation set",
)
parser.add_argument(
"--seed", default=None, type=int, help="seed for initializing training. "
)
parser.add_argument(
"--device",
type=str,
default="cuda",
help="GPU ID for this process (default: 'cuda')",
)
parser.add_argument(
"--sigma",
type=float,
default=1.0,
metavar="S",
help="Noise multiplier (default 1.0)",
)
parser.add_argument(
"-c",
"--max-per-sample-grad_norm",
type=float,
default=1.0,
metavar="C",
help="Clip per-sample gradients to this norm (default 1.0)",
)
parser.add_argument(
"--disable-dp",
action="store_true",
default=False,
help="Disable privacy training and just train with vanilla SGD",
)
parser.add_argument(
"--secure-rng",
action="store_true",
default=False,
help="Enable Secure RNG to have trustworthy privacy guarantees. Comes at a performance cost",
)
parser.add_argument(
"--delta",
type=float,
default=1e-5,
metavar="D",
help="Target delta (default: 1e-5)",
)
parser.add_argument(
"--checkpoint-file",
type=str,
default="checkpoint",
help="path to save check points",
)
parser.add_argument(
"--data-root",
type=str,
default="../cifar10",
help="Where CIFAR10 is/will be stored",
)
parser.add_argument(
"--log-dir", type=str, default="", help="Where Tensorboard log will be stored"
)
parser.add_argument(
"--optim",
type=str,
default="Adam",
help="Optimizer to use (Adam, RMSprop, SGD)",
)
args = parser.parse_args()
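# NB: the next line force-disables DP for this prototype script, overriding the --disable-dp flag.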
args.disable_dp = True
if args.disable_dp and args.n_accumulation_steps > 1:
raise ValueError("Virtual steps only works with enabled DP")
# The following few lines enable stats gathering about the run:
# 1. where the stats should be logged
stats.set_global_summary_writer(
tensorboard.SummaryWriter(os.path.join("/tmp/stat", args.log_dir))
)
# 2. enable stats
stats.add(
# stats about gradient norms aggregated for all layers
stats.Stat(stats.StatType.GRAD, "AllLayers", frequency=0.1),
# stats about gradient norms per layer
stats.Stat(stats.StatType.GRAD, "PerLayer", frequency=0.1),
# stats about clipping
stats.Stat(stats.StatType.GRAD, "ClippingStats", frequency=0.1),
# stats on training accuracy
stats.Stat(stats.StatType.TRAIN, "accuracy", frequency=0.01),
# stats on validation accuracy
stats.Stat(stats.StatType.TEST, "accuracy"),
)
# The following lines enable stat gathering for the clipping process
# and set a default of per layer clipping for the Privacy Engine
clipping = {"clip_per_layer": False, "enable_stat": True}
if args.secure_rng:
assert False
try:
import torchcsprng as prng
except ImportError as e:
msg = (
"To use secure RNG, you must install the torchcsprng package! "
"Check out the instructions here: https://github.com/pytorch/csprng#installation"
)
raise ImportError(msg) from e
generator = prng.create_random_device_generator("/dev/urandom")
else:
generator = None
augmentations = [
transforms.RandomCrop(32, padding=4),
transforms.RandomHorizontalFlip(),
]
normalize = [
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
]
train_transform = transforms.Compose(
augmentations + normalize if args.disable_dp else normalize
)
test_transform = transforms.Compose(normalize)
train_dataset = CIFAR10(
root=args.data_root, train=True, download=True, transform=train_transform
)
train_loader = torch.utils.data.DataLoader(
train_dataset,
batch_size=args.batch_size,
shuffle=True,
num_workers=args.workers,
drop_last=True,
generator=generator,
)
test_dataset = CIFAR10(
root=args.data_root, train=False, download=True, transform=test_transform
)
test_loader = torch.utils.data.DataLoader(
test_dataset,
batch_size=args.batch_size,
shuffle=False,
num_workers=args.workers,
)
best_acc1 = 0
device = torch.device(args.device)
model = convert_batchnorm_modules(models.resnet18(num_classes=10))
# model = CIFAR10Model()
model = model.to(device)
if args.optim == "SGD":
optimizer = optim.SGD(
model.parameters(),
lr=args.lr,
momentum=args.momentum,
weight_decay=args.weight_decay,
)
elif args.optim == "RMSprop":
optimizer = optim.RMSprop(model.parameters(), lr=args.lr)
elif args.optim == "Adam":
optimizer = optim.Adam(model.parameters(), lr=args.lr)
else:
raise NotImplementedError("Optimizer not recognized. Please check spelling")
if not args.disable_dp:
privacy_engine = PrivacyEngine(
model,
batch_size=args.batch_size * args.n_accumulation_steps,
sample_size=len(train_dataset),
alphas=[1 + x / 10.0 for x in range(1, 100)] + list(range(12, 64)),
noise_multiplier=args.sigma,
max_grad_norm=args.max_per_sample_grad_norm,
secure_rng=args.secure_rng,
**clipping,
)
privacy_engine.attach(optimizer)
for epoch in range(args.start_epoch, args.epochs + 1):
train(args, model, train_loader, optimizer, epoch, device)
top1_acc = test(args, model, test_loader, device)
# remember best acc@1 and save checkpoint
is_best = top1_acc > best_acc1
best_acc1 = max(top1_acc, best_acc1)
save_checkpoint(
{
"epoch": epoch + 1,
"arch": "ResNet18",
"state_dict": model.state_dict(),
"best_acc1": best_acc1,
"optimizer": optimizer.state_dict(),
},
is_best,
filename=args.checkpoint_file + ".tar",
)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,405 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
Runs CIFAR10 training with differential privacy.
"""
import argparse
import os
import shutil
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data
import torch.utils.data.distributed
import torch.utils.tensorboard as tensorboard
import torchvision.models as models
import torchvision.transforms as transforms
from opacus import PrivacyEngine
from opacus.utils import stats
from opacus.utils.module_modification import convert_batchnorm_modules
from torchvision.datasets import CIFAR10
from tqdm import tqdm
def save_checkpoint(state, is_best, filename="checkpoint.tar"):
torch.save(state, filename)
if is_best:
shutil.copyfile(filename, "model_best.pth.tar")
def accuracy(preds, labels):
return (preds == labels).mean()
def train(args, model, train_loader, optimizer, epoch, device):
model.train()
criterion = nn.CrossEntropyLoss()
losses = []
top1_acc = []
for i, (images, target) in enumerate(tqdm(train_loader)):
images = images.to(device)
target = target.to(device)
# compute output
output = model(images)
loss = criterion(output, target)
preds = np.argmax(output.detach().cpu().numpy(), axis=1)
labels = target.detach().cpu().numpy()
# measure accuracy and record loss
acc1 = accuracy(preds, labels)
losses.append(loss.item())
top1_acc.append(acc1)
stats.update(stats.StatType.TRAIN, acc1=acc1)
# compute gradient and do SGD step
loss.backward()
# make sure we take a step after processing the last mini-batch in the
# epoch to ensure we start the next epoch with a clean state
if ((i + 1) % args.n_accumulation_steps == 0) or ((i + 1) == len(train_loader)):
optimizer.step()
optimizer.zero_grad()
else:
optimizer.virtual_step()
if i % args.print_freq == 0:
if not args.disable_dp:
epsilon, best_alpha = optimizer.privacy_engine.get_privacy_spent(
args.delta
)
print(
f"\tTrain Epoch: {epoch} \t"
f"Loss: {np.mean(losses):.6f} "
f"Acc@1: {np.mean(top1_acc):.6f} "
f"(ε = {epsilon:.2f}, δ = {args.delta}) for α = {best_alpha}"
)
else:
print(
f"\tTrain Epoch: {epoch} \t"
f"Loss: {np.mean(losses):.6f} "
f"Acc@1: {np.mean(top1_acc):.6f} "
)
def test(args, model, test_loader, device):
model.eval()
criterion = nn.CrossEntropyLoss()
losses = []
top1_acc = []
with torch.no_grad():
for images, target in tqdm(test_loader):
images = images.to(device)
target = target.to(device)
output = model(images)
loss = criterion(output, target)
preds = np.argmax(output.detach().cpu().numpy(), axis=1)
labels = target.detach().cpu().numpy()
acc1 = accuracy(preds, labels)
losses.append(loss.item())
top1_acc.append(acc1)
top1_avg = np.mean(top1_acc)
stats.update(stats.StatType.TEST, acc1=top1_avg)
print(f"\tTest set:" f"Loss: {np.mean(losses):.6f} " f"Acc@1: {top1_avg :.6f} ")
return np.mean(top1_acc)
def main():
parser = argparse.ArgumentParser(description="PyTorch CIFAR10 DP Training")
parser.add_argument(
"-j",
"--workers",
default=2,
type=int,
metavar="N",
help="number of data loading workers (default: 2)",
)
parser.add_argument(
"--epochs",
default=90,
type=int,
metavar="N",
help="number of total epochs to run",
)
parser.add_argument(
"--start-epoch",
default=1,
type=int,
metavar="N",
help="manual epoch number (useful on restarts)",
)
parser.add_argument(
"-b",
"--batch-size",
default=256,
type=int,
metavar="N",
help="mini-batch size (default: 256), this is the total "
"batch size of all GPUs on the current node when "
"using Data Parallel or Distributed Data Parallel",
)
parser.add_argument(
"-na",
"--n_accumulation_steps",
default=1,
type=int,
metavar="N",
help="number of mini-batches to accumulate into an effective batch",
)
parser.add_argument(
"--lr",
"--learning-rate",
default=0.001,
type=float,
metavar="LR",
help="initial learning rate",
dest="lr",
)
parser.add_argument(
"--momentum", default=0.9, type=float, metavar="M", help="SGD momentum"
)
parser.add_argument(
"--wd",
"--weight-decay",
default=5e-4,
type=float,
metavar="W",
help="SGD weight decay (default: 1e-4)",
dest="weight_decay",
)
parser.add_argument(
"-p",
"--print-freq",
default=10,
type=int,
metavar="N",
help="print frequency (default: 10)",
)
parser.add_argument(
"--resume",
default="",
type=str,
metavar="PATH",
help="path to latest checkpoint (default: none)",
)
parser.add_argument(
"-e",
"--evaluate",
dest="evaluate",
action="store_true",
help="evaluate model on validation set",
)
parser.add_argument(
"--seed", default=None, type=int, help="seed for initializing training. "
)
parser.add_argument(
"--device",
type=str,
default="cuda",
help="GPU ID for this process (default: 'cuda')",
)
parser.add_argument(
"--sigma",
type=float,
default=1.0,
metavar="S",
help="Noise multiplier (default 1.0)",
)
parser.add_argument(
"-c",
"--max-per-sample-grad_norm",
type=float,
default=1.0,
metavar="C",
help="Clip per-sample gradients to this norm (default 1.0)",
)
parser.add_argument(
"--disable-dp",
action="store_true",
default=False,
help="Disable privacy training and just train with vanilla SGD",
)
parser.add_argument(
"--secure-rng",
action="store_true",
default=False,
help="Enable Secure RNG to have trustworthy privacy guarantees. Comes at a performance cost",
)
parser.add_argument(
"--delta",
type=float,
default=1e-5,
metavar="D",
help="Target delta (default: 1e-5)",
)
parser.add_argument(
"--checkpoint-file",
type=str,
default="checkpoint",
help="path to save check points",
)
parser.add_argument(
"--data-root",
type=str,
default="../cifar10",
help="Where CIFAR10 is/will be stored",
)
parser.add_argument(
"--log-dir", type=str, default="", help="Where Tensorboard log will be stored"
)
parser.add_argument(
"--optim",
type=str,
default="Adam",
help="Optimizer to use (Adam, RMSprop, SGD)",
)
args = parser.parse_args()
if args.disable_dp and args.n_accumulation_steps > 1:
raise ValueError("Virtual steps only works with enabled DP")
# The following few lines enable stats gathering about the run:
# 1. where the stats should be logged
stats.set_global_summary_writer(
tensorboard.SummaryWriter(os.path.join("/tmp/stat", args.log_dir))
)
# 2. enable stats
stats.add(
# stats about gradient norms aggregated for all layers
stats.Stat(stats.StatType.GRAD, "AllLayers", frequency=0.1),
# stats about gradient norms per layer
stats.Stat(stats.StatType.GRAD, "PerLayer", frequency=0.1),
# stats about clipping
stats.Stat(stats.StatType.GRAD, "ClippingStats", frequency=0.1),
# stats on training accuracy
stats.Stat(stats.StatType.TRAIN, "accuracy", frequency=0.01),
# stats on validation accuracy
stats.Stat(stats.StatType.TEST, "accuracy"),
)
# The following lines enable stat gathering for the clipping process
# and set a default of per layer clipping for the Privacy Engine
clipping = {"clip_per_layer": False, "enable_stat": True}
if args.secure_rng:
try:
import torchcsprng as prng
except ImportError as e:
msg = (
"To use secure RNG, you must install the torchcsprng package! "
"Check out the instructions here: https://github.com/pytorch/csprng#installation"
)
raise ImportError(msg) from e
generator = prng.create_random_device_generator("/dev/urandom")
else:
generator = None
augmentations = [
transforms.RandomCrop(32, padding=4),
transforms.RandomHorizontalFlip(),
]
normalize = [
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
]
train_transform = transforms.Compose(
augmentations + normalize if args.disable_dp else normalize
)
test_transform = transforms.Compose(normalize)
train_dataset = CIFAR10(
root=args.data_root, train=True, download=True, transform=train_transform
)
train_loader = torch.utils.data.DataLoader(
train_dataset,
batch_size=args.batch_size,
shuffle=True,
num_workers=args.workers,
drop_last=True,
generator=generator,
)
test_dataset = CIFAR10(
root=args.data_root, train=False, download=True, transform=test_transform
)
test_loader = torch.utils.data.DataLoader(
test_dataset,
batch_size=args.batch_size,
shuffle=False,
num_workers=args.workers,
)
best_acc1 = 0
device = torch.device(args.device)
model = convert_batchnorm_modules(models.resnet18(num_classes=10))
model = model.to(device)
if args.optim == "SGD":
optimizer = optim.SGD(
model.parameters(),
lr=args.lr,
momentum=args.momentum,
weight_decay=args.weight_decay,
)
elif args.optim == "RMSprop":
optimizer = optim.RMSprop(model.parameters(), lr=args.lr)
elif args.optim == "Adam":
optimizer = optim.Adam(model.parameters(), lr=args.lr)
else:
raise NotImplementedError("Optimizer not recognized. Please check spelling")
if not args.disable_dp:
privacy_engine = PrivacyEngine(
model,
batch_size=args.batch_size * args.n_accumulation_steps,
sample_size=len(train_dataset),
alphas=[1 + x / 10.0 for x in range(1, 100)] + list(range(12, 64)),  # RDP (Renyi DP) orders for the privacy accountant
noise_multiplier=args.sigma,
max_grad_norm=args.max_per_sample_grad_norm,
secure_rng=args.secure_rng,
**clipping,
)
privacy_engine.attach(optimizer)
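# attach() hooks the engine into the optimizer: optimizer.step() then clips per-sample
# gradients and adds noise, while optimizer.virtual_step() only accumulates (pre-1.0 Opacus API).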
for epoch in range(args.start_epoch, args.epochs + 1):
train(args, model, train_loader, optimizer, epoch, device)
top1_acc = test(args, model, test_loader, device)
# remember best acc@1 and save checkpoint
is_best = top1_acc > best_acc1
best_acc1 = max(top1_acc, best_acc1)
save_checkpoint(
{
"epoch": epoch + 1,
"arch": "ResNet18",
"state_dict": model.state_dict(),
"best_acc1": best_acc1,
"optimizer": optimizer.state_dict(),
},
is_best,
filename=args.checkpoint_file + ".tar",
)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,475 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
Runs CIFAR10 training with differential privacy.
"""
import argparse
import os
import shutil
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data
import torch.utils.data.distributed
import torch.utils.tensorboard as tensorboard
import torchvision.models as models
import torchvision.transforms as transforms
from opacus import PrivacyEngine
from opacus.utils import stats
from opacus.utils.module_modification import convert_batchnorm_modules
from torchvision.datasets import CIFAR10
from tqdm import tqdm
from make_functional import make_functional, load_weights
from torch.eager_transforms import vmap, grad_with_value
from functools import partial
# from resnet import resnet18
def save_checkpoint(state, is_best, filename="checkpoint.tar"):
torch.save(state, filename)
if is_best:
shutil.copyfile(filename, "model_best.pth.tar")
def accuracy(preds, labels):
return (preds == labels).mean()
def compute_norms(sample_grads):
batch_size = sample_grads[0].shape[0]
norms = [sample_grad.view(batch_size, -1).norm(2, dim=-1) for sample_grad in sample_grads]
norms = torch.stack(norms, dim=0).norm(2, dim=0)
return norms
def clip_and_accumulate_and_add_noise(model, max_per_sample_grad_norm=1.0, noise_multiplier=1.0):
sample_grads = tuple(param.grad_sample for param in model.parameters())
# step 0: compute the norms
sample_norms = compute_norms(sample_grads)
# step 1: compute clipping factors
clip_factor = max_per_sample_grad_norm / (sample_norms + 1e-6)
clip_factor = clip_factor.clamp(max=1.0)
# step 2: clip
grads = tuple(torch.einsum('i,i...', clip_factor, sample_grad)
for sample_grad in sample_grads)
# step 3: add gaussian noise
stddev = max_per_sample_grad_norm * noise_multiplier
noises = tuple(torch.normal(0, stddev, grad_param.shape, device=grad_param.device)
for grad_param in grads)
grads = tuple(noise + grad_param for noise, grad_param in zip(noises, grads))
# step 4: assign the new grads, delete the sample grads
for param, param_grad in zip(model.parameters(), grads):
param.grad = param_grad
del param.grad_sample
def train(args, model, train_loader, optimizer, epoch, device):
model.train()
criterion = nn.CrossEntropyLoss()
losses = []
top1_acc = []
for i, (images, target) in enumerate(tqdm(train_loader)):
images = images.to(device)
target = target.to(device)
# Step 1: compute per-sample-grads
# In order to use functional vmap+grad, we need to be able to
# pass the weights to a model.
weights, func_model, descriptors = make_functional(model)
# To use vmap+grad to compute per-sample-grads, the forward pass
# must be re-formulated on a single example.
# We use the `grad` operator to compute forward+backward on a single example,
# and finally `vmap` to do forward+backward on multiple examples.
def compute_loss_and_output(weights, image, target):
images = image.unsqueeze(0)
targets = target.unsqueeze(0)
output = func_model(weights, (images,))
loss = criterion(output, targets)
return loss, output.squeeze(0)
# `grad(f)` is a functional API that returns a function `f'` that
# computes gradients by running both the forward and backward pass.
# We want to extract some intermediate
# values from the computation (i.e. the loss and output).
#
# To extract the loss, we use the `grad_with_value` API, which returns both the
# gradient of the loss w.r.t. the weights and the loss itself.
#
# To extract the output, we use the `has_aux=True` flag.
# `has_aux=True` assumes that `f` returns a tuple of two values, where the first
# is to be differentiated and the second ("auxiliary value") is not.
# `f'` then returns the gradient of the loss w.r.t. the weights, the loss,
# and the auxiliary value.
grads_loss_output = grad_with_value(compute_loss_and_output, has_aux=True)
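# A toy sketch of the same pattern (hypothetical tensors w and x, assuming the prototype API above):
#   f = lambda w, x: ((w * x).sum(), x)        # returns (value_to_differentiate, aux)
#   df = grad_with_value(f, has_aux=True)
#   grad_w, value, aux = df(w, x)              # grad_w has the same shape as w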
def packed(weights, images, target):
grads, loss, output = grads_loss_output(weights, images, target)
result = tuple([*grads, loss, output])
return result
result = vmap(partial(packed, weights))(images, target)
sample_grads, sample_loss, output, = result[:-2], result[-2], result[-1]
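# Expected shapes under this prototype's conventions: each tensor in sample_grads is
# [batch_size, *weight.shape], sample_loss is [batch_size], and output is [batch_size, num_classes].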
loss = sample_loss.mean()
# `load_weights` is the inverse operation of make_functional. We put
# things back into a model so that they're easier to manipulate
load_weights(model, descriptors, weights)
for grad_sample, weight in zip(sample_grads, model.parameters()):
weight.grad_sample = grad_sample.detach()
# Step 2: Clip the per-sample-grads, sum them to form grads, and add noise
clip_and_accumulate_and_add_noise(
model, args.max_per_sample_grad_norm, args.sigma)
preds = np.argmax(output.detach().cpu().numpy(), axis=1)
labels = target.detach().cpu().numpy()
losses.append(loss.item())
# measure accuracy and record loss
acc1 = accuracy(preds, labels)
top1_acc.append(acc1)
stats.update(stats.StatType.TRAIN, acc1=acc1)
# make sure we take a step after processing the last mini-batch in the
# epoch to ensure we start the next epoch with a clean state
if ((i + 1) % args.n_accumulation_steps == 0) or ((i + 1) == len(train_loader)):
optimizer.step()
optimizer.zero_grad()
else:
optimizer.virtual_step()
if i % args.print_freq == 0:
print(
f"\tTrain Epoch: {epoch} \t"
f"Loss: {np.mean(losses):.6f} "
f"Acc@1: {np.mean(top1_acc):.6f} "
)
def test(args, model, test_loader, device):
model.eval()
criterion = nn.CrossEntropyLoss()
losses = []
top1_acc = []
with torch.no_grad():
for images, target in tqdm(test_loader):
images = images.to(device)
target = target.to(device)
output = model(images)
loss = criterion(output, target)
preds = np.argmax(output.detach().cpu().numpy(), axis=1)
labels = target.detach().cpu().numpy()
acc1 = accuracy(preds, labels)
losses.append(loss.item())
top1_acc.append(acc1)
top1_avg = np.mean(top1_acc)
stats.update(stats.StatType.TEST, acc1=top1_avg)
print(f"\tTest set:" f"Loss: {np.mean(losses):.6f} " f"Acc@1: {top1_avg :.6f} ")
return np.mean(top1_acc)
def main():
parser = argparse.ArgumentParser(description="PyTorch CIFAR10 DP Training")
parser.add_argument(
"-j",
"--workers",
default=2,
type=int,
metavar="N",
help="number of data loading workers (default: 2)",
)
parser.add_argument(
"--epochs",
default=90,
type=int,
metavar="N",
help="number of total epochs to run",
)
parser.add_argument(
"--start-epoch",
default=1,
type=int,
metavar="N",
help="manual epoch number (useful on restarts)",
)
parser.add_argument(
"-b",
"--batch-size",
# This should be 256, but that OOMs using the prototype.
default=64,
type=int,
metavar="N",
help="mini-batch size (default: 64), this is the total "
"batch size of all GPUs on the current node when "
"using Data Parallel or Distributed Data Parallel",
)
parser.add_argument(
"-na",
"--n_accumulation_steps",
default=1,
type=int,
metavar="N",
help="number of mini-batches to accumulate into an effective batch",
)
parser.add_argument(
"--lr",
"--learning-rate",
default=0.001,
type=float,
metavar="LR",
help="initial learning rate",
dest="lr",
)
parser.add_argument(
"--momentum", default=0.9, type=float, metavar="M", help="SGD momentum"
)
parser.add_argument(
"--wd",
"--weight-decay",
default=5e-4,
type=float,
metavar="W",
help="SGD weight decay (default: 1e-4)",
dest="weight_decay",
)
parser.add_argument(
"-p",
"--print-freq",
default=10,
type=int,
metavar="N",
help="print frequency (default: 10)",
)
parser.add_argument(
"--resume",
default="",
type=str,
metavar="PATH",
help="path to latest checkpoint (default: none)",
)
parser.add_argument(
"-e",
"--evaluate",
dest="evaluate",
action="store_true",
help="evaluate model on validation set",
)
parser.add_argument(
"--seed", default=None, type=int, help="seed for initializing training. "
)
parser.add_argument(
"--device",
type=str,
default="cuda",
help="GPU ID for this process (default: 'cuda')",
)
parser.add_argument(
"--sigma",
type=float,
default=1.0,
metavar="S",
help="Noise multiplier (default 1.0)",
)
parser.add_argument(
"-c",
"--max-per-sample-grad_norm",
type=float,
default=1.0,
metavar="C",
help="Clip per-sample gradients to this norm (default 1.0)",
)
parser.add_argument(
"--disable-dp",
action="store_true",
default=False,
help="Disable privacy training and just train with vanilla SGD",
)
parser.add_argument(
"--secure-rng",
action="store_true",
default=False,
help="Enable Secure RNG to have trustworthy privacy guarantees. Comes at a performance cost",
)
parser.add_argument(
"--delta",
type=float,
default=1e-5,
metavar="D",
help="Target delta (default: 1e-5)",
)
parser.add_argument(
"--checkpoint-file",
type=str,
default="checkpoint",
help="path to save check points",
)
parser.add_argument(
"--data-root",
type=str,
default="../cifar10",
help="Where CIFAR10 is/will be stored",
)
parser.add_argument(
"--log-dir", type=str, default="", help="Where Tensorboard log will be stored"
)
parser.add_argument(
"--optim",
type=str,
default="Adam",
help="Optimizer to use (Adam, RMSprop, SGD)",
)
args = parser.parse_args()
args.disable_dp = True
if args.disable_dp and args.n_accumulation_steps > 1:
raise ValueError("Virtual steps only works with enabled DP")
# The following few lines enable stats gathering about the run:
# 1. where the stats should be logged
stats.set_global_summary_writer(
tensorboard.SummaryWriter(os.path.join("/tmp/stat", args.log_dir))
)
# 2. enable stats
stats.add(
# stats about gradient norms aggregated for all layers
stats.Stat(stats.StatType.GRAD, "AllLayers", frequency=0.1),
# stats about gradient norms per layer
stats.Stat(stats.StatType.GRAD, "PerLayer", frequency=0.1),
# stats about clipping
stats.Stat(stats.StatType.GRAD, "ClippingStats", frequency=0.1),
# stats on training accuracy
stats.Stat(stats.StatType.TRAIN, "accuracy", frequency=0.01),
# stats on validation accuracy
stats.Stat(stats.StatType.TEST, "accuracy"),
)
# The following lines enable stat gathering for the clipping process
# and set a default of per layer clipping for the Privacy Engine
clipping = {"clip_per_layer": False, "enable_stat": True}
if args.secure_rng:
assert False
try:
import torchcsprng as prng
except ImportError as e:
msg = (
"To use secure RNG, you must install the torchcsprng package! "
"Check out the instructions here: https://github.com/pytorch/csprng#installation"
)
raise ImportError(msg) from e
generator = prng.create_random_device_generator("/dev/urandom")
else:
generator = None
augmentations = [
transforms.RandomCrop(32, padding=4),
transforms.RandomHorizontalFlip(),
]
normalize = [
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
]
train_transform = transforms.Compose(
augmentations + normalize if args.disable_dp else normalize
)
test_transform = transforms.Compose(normalize)
train_dataset = CIFAR10(
root=args.data_root, train=True, download=True, transform=train_transform
)
train_loader = torch.utils.data.DataLoader(
train_dataset,
batch_size=args.batch_size,
shuffle=True,
num_workers=args.workers,
drop_last=True,
generator=generator,
)
test_dataset = CIFAR10(
root=args.data_root, train=False, download=True, transform=test_transform
)
test_loader = torch.utils.data.DataLoader(
test_dataset,
batch_size=args.batch_size,
shuffle=False,
num_workers=args.workers,
)
best_acc1 = 0
device = torch.device(args.device)
model = convert_batchnorm_modules(models.resnet18(num_classes=10))
# model = CIFAR10Model()
model = model.to(device)
if args.optim == "SGD":
optimizer = optim.SGD(
model.parameters(),
lr=args.lr,
momentum=args.momentum,
weight_decay=args.weight_decay,
)
elif args.optim == "RMSprop":
optimizer = optim.RMSprop(model.parameters(), lr=args.lr)
elif args.optim == "Adam":
optimizer = optim.Adam(model.parameters(), lr=args.lr)
else:
raise NotImplementedError("Optimizer not recognized. Please check spelling")
if not args.disable_dp:
privacy_engine = PrivacyEngine(
model,
batch_size=args.batch_size * args.n_accumulation_steps,
sample_size=len(train_dataset),
alphas=[1 + x / 10.0 for x in range(1, 100)] + list(range(12, 64)),
noise_multiplier=args.sigma,
max_grad_norm=args.max_per_sample_grad_norm,
secure_rng=args.secure_rng,
**clipping,
)
privacy_engine.attach(optimizer)
for epoch in range(args.start_epoch, args.epochs + 1):
train(args, model, train_loader, optimizer, epoch, device)
top1_acc = test(args, model, test_loader, device)
# remember best acc@1 and save checkpoint
is_best = top1_acc > best_acc1
best_acc1 = max(top1_acc, best_acc1)
save_checkpoint(
{
"epoch": epoch + 1,
"arch": "ResNet18",
"state_dict": model.state_dict(),
"best_acc1": best_acc1,
"optimizer": optimizer.state_dict(),
},
is_best,
filename=args.checkpoint_file + ".tar",
)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,438 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
Runs CIFAR10 training with differential privacy.
"""
import argparse
import os
import shutil
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data
import torch.utils.data.distributed
import torch.utils.tensorboard as tensorboard
import torchvision.models as models
import torchvision.transforms as transforms
from opacus import PrivacyEngine
from opacus.utils import stats
from opacus.utils.module_modification import convert_batchnorm_modules
from torchvision.datasets import CIFAR10
from tqdm import tqdm
from functools import partial
from make_functional import make_functional, load_weights
# NB: The following might not exist depending on what you're using
from torch import vmap
from functional_utils import grad, grad_with_value
def save_checkpoint(state, is_best, filename="checkpoint.tar"):
torch.save(state, filename)
if is_best:
shutil.copyfile(filename, "model_best.pth.tar")
def accuracy(preds, labels):
return (preds == labels).mean()
def train(args, model, train_loader, optimizer, epoch, device):
model.train()
criterion = nn.CrossEntropyLoss()
losses = []
top1_acc = []
for i, (images, target) in enumerate(tqdm(train_loader)):
images = images.to(device)
target = target.to(device)
# Step 1: compute per-sample-grads
# TODO(rzou): does the group norm work correctly?
weights, func_model, descriptors = make_functional(model)
[weight.requires_grad_(False) for weight in weights]
def compute_loss_and_output(weights, image, target):
images = image.unsqueeze(0)
target = target.unsqueeze(0)
output = func_model(weights, (images,))
loss = criterion(output, target)
return loss, output.squeeze(0)
grads_loss_output = grad_with_value(compute_loss_and_output, has_aux=True)
grads, sample_loss, output = vmap(partial(grads_loss_output, weights))(images, target)
loss = sample_loss.mean(0)
# Step 2: Clip the per-sample-grads and sum them to form grads
# TODO(rzou): Right now we just sum the grads. Instead we need to clip them.
for sample_grad, weight in zip(grads, weights):
weight_grad = sample_grad.sum(0)
weight.grad = weight_grad
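# NB: summing the per-sample grads yields the gradient of the summed (not averaged) loss;
# per the TODO above, the DP-SGD recipe would clip each sample_grad before this sum and add noise.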
# `load_weights` is the inverse operation of make_functional. We put
# things back into a model so that we can directly apply optimizers.
# TODO(rzou): this might not be necessary, optimizers just take
# the params straight up.
[weight.requires_grad_(True) for weight in weights]
load_weights(model, descriptors, weights, as_params=True)
preds = np.argmax(output.detach().cpu().numpy(), axis=1)
labels = target.detach().cpu().numpy()
# measure accuracy and record loss
acc1 = accuracy(preds, labels)
losses.append(loss.item())
top1_acc.append(acc1)
stats.update(stats.StatType.TRAIN, acc1=acc1)
# make sure we take a step after processing the last mini-batch in the
# epoch to ensure we start the next epoch with a clean state
if ((i + 1) % args.n_accumulation_steps == 0) or ((i + 1) == len(train_loader)):
optimizer.step()
optimizer.zero_grad()
else:
optimizer.virtual_step()
if i % args.print_freq == 0:
if not args.disable_dp:
epsilon, best_alpha = optimizer.privacy_engine.get_privacy_spent(
args.delta
)
print(
f"\tTrain Epoch: {epoch} \t"
f"Loss: {np.mean(losses):.6f} "
f"Acc@1: {np.mean(top1_acc):.6f} "
f"(ε = {epsilon:.2f}, δ = {args.delta}) for α = {best_alpha}"
)
else:
print(
f"\tTrain Epoch: {epoch} \t"
f"Loss: {np.mean(losses):.6f} "
f"Acc@1: {np.mean(top1_acc):.6f} "
)
def test(args, model, test_loader, device):
model.eval()
criterion = nn.CrossEntropyLoss()
losses = []
top1_acc = []
with torch.no_grad():
for images, target in tqdm(test_loader):
images = images.to(device)
target = target.to(device)
output = model(images)
loss = criterion(output, target)
preds = np.argmax(output.detach().cpu().numpy(), axis=1)
labels = target.detach().cpu().numpy()
acc1 = accuracy(preds, labels)
losses.append(loss.item())
top1_acc.append(acc1)
top1_avg = np.mean(top1_acc)
stats.update(stats.StatType.TEST, acc1=top1_avg)
print(f"\tTest set:" f"Loss: {np.mean(losses):.6f} " f"Acc@1: {top1_avg :.6f} ")
return np.mean(top1_acc)
def main():
parser = argparse.ArgumentParser(description="PyTorch CIFAR10 DP Training")
parser.add_argument(
"-j",
"--workers",
default=2,
type=int,
metavar="N",
help="number of data loading workers (default: 2)",
)
parser.add_argument(
"--epochs",
default=90,
type=int,
metavar="N",
help="number of total epochs to run",
)
parser.add_argument(
"--start-epoch",
default=1,
type=int,
metavar="N",
help="manual epoch number (useful on restarts)",
)
parser.add_argument(
"-b",
"--batch-size",
default=256,
type=int,
metavar="N",
help="mini-batch size (default: 256), this is the total "
"batch size of all GPUs on the current node when "
"using Data Parallel or Distributed Data Parallel",
)
parser.add_argument(
"-na",
"--n_accumulation_steps",
default=1,
type=int,
metavar="N",
help="number of mini-batches to accumulate into an effective batch",
)
parser.add_argument(
"--lr",
"--learning-rate",
default=0.001,
type=float,
metavar="LR",
help="initial learning rate",
dest="lr",
)
parser.add_argument(
"--momentum", default=0.9, type=float, metavar="M", help="SGD momentum"
)
parser.add_argument(
"--wd",
"--weight-decay",
default=5e-4,
type=float,
metavar="W",
help="SGD weight decay (default: 1e-4)",
dest="weight_decay",
)
parser.add_argument(
"-p",
"--print-freq",
default=10,
type=int,
metavar="N",
help="print frequency (default: 10)",
)
parser.add_argument(
"--resume",
default="",
type=str,
metavar="PATH",
help="path to latest checkpoint (default: none)",
)
parser.add_argument(
"-e",
"--evaluate",
dest="evaluate",
action="store_true",
help="evaluate model on validation set",
)
parser.add_argument(
"--seed", default=None, type=int, help="seed for initializing training. "
)
parser.add_argument(
"--device",
type=str,
default="cuda",
help="GPU ID for this process (default: 'cuda')",
)
parser.add_argument(
"--sigma",
type=float,
default=1.0,
metavar="S",
help="Noise multiplier (default 1.0)",
)
parser.add_argument(
"-c",
"--max-per-sample-grad_norm",
type=float,
default=1.0,
metavar="C",
help="Clip per-sample gradients to this norm (default 1.0)",
)
parser.add_argument(
"--disable-dp",
action="store_true",
default=False,
help="Disable privacy training and just train with vanilla SGD",
)
parser.add_argument(
"--secure-rng",
action="store_true",
default=False,
help="Enable Secure RNG to have trustworthy privacy guarantees. Comes at a performance cost",
)
parser.add_argument(
"--delta",
type=float,
default=1e-5,
metavar="D",
help="Target delta (default: 1e-5)",
)
parser.add_argument(
"--checkpoint-file",
type=str,
default="checkpoint",
help="path to save check points",
)
parser.add_argument(
"--data-root",
type=str,
default="../cifar10",
help="Where CIFAR10 is/will be stored",
)
parser.add_argument(
"--log-dir", type=str, default="", help="Where Tensorboard log will be stored"
)
parser.add_argument(
"--optim",
type=str,
default="Adam",
help="Optimizer to use (Adam, RMSprop, SGD)",
)
args = parser.parse_args()
args.disable_dp = True
if args.disable_dp and args.n_accumulation_steps > 1:
raise ValueError("Virtual steps only works with enabled DP")
# The following few lines enable stats gathering about the run:
# 1. where the stats should be logged
stats.set_global_summary_writer(
tensorboard.SummaryWriter(os.path.join("/tmp/stat", args.log_dir))
)
# 2. enable stats
stats.add(
# stats about gradient norms aggregated for all layers
stats.Stat(stats.StatType.GRAD, "AllLayers", frequency=0.1),
# stats about gradient norms per layer
stats.Stat(stats.StatType.GRAD, "PerLayer", frequency=0.1),
# stats about clipping
stats.Stat(stats.StatType.GRAD, "ClippingStats", frequency=0.1),
# stats on training accuracy
stats.Stat(stats.StatType.TRAIN, "accuracy", frequency=0.01),
# stats on validation accuracy
stats.Stat(stats.StatType.TEST, "accuracy"),
)
# The following lines enable stat gathering for the clipping process
# and set a default of per layer clipping for the Privacy Engine
clipping = {"clip_per_layer": False, "enable_stat": True}
if args.secure_rng:
assert False
try:
import torchcsprng as prng
except ImportError as e:
msg = (
"To use secure RNG, you must install the torchcsprng package! "
"Check out the instructions here: https://github.com/pytorch/csprng#installation"
)
raise ImportError(msg) from e
generator = prng.create_random_device_generator("/dev/urandom")
else:
generator = None
augmentations = [
transforms.RandomCrop(32, padding=4),
transforms.RandomHorizontalFlip(),
]
normalize = [
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
]
train_transform = transforms.Compose(
augmentations + normalize if args.disable_dp else normalize
)
test_transform = transforms.Compose(normalize)
train_dataset = CIFAR10(
root=args.data_root, train=True, download=True, transform=train_transform
)
train_loader = torch.utils.data.DataLoader(
train_dataset,
batch_size=args.batch_size,
shuffle=True,
num_workers=args.workers,
drop_last=True,
generator=generator,
)
test_dataset = CIFAR10(
root=args.data_root, train=False, download=True, transform=test_transform
)
test_loader = torch.utils.data.DataLoader(
test_dataset,
batch_size=args.batch_size,
shuffle=False,
num_workers=args.workers,
)
best_acc1 = 0
device = torch.device(args.device)
model = convert_batchnorm_modules(models.resnet18(num_classes=10))
model = model.to(device)
if args.optim == "SGD":
optimizer = optim.SGD(
model.parameters(),
lr=args.lr,
momentum=args.momentum,
weight_decay=args.weight_decay,
)
elif args.optim == "RMSprop":
optimizer = optim.RMSprop(model.parameters(), lr=args.lr)
elif args.optim == "Adam":
optimizer = optim.Adam(model.parameters(), lr=args.lr)
else:
raise NotImplementedError("Optimizer not recognized. Please check spelling")
if not args.disable_dp:
privacy_engine = PrivacyEngine(
model,
batch_size=args.batch_size * args.n_accumulation_steps,
sample_size=len(train_dataset),
alphas=[1 + x / 10.0 for x in range(1, 100)] + list(range(12, 64)),
noise_multiplier=args.sigma,
max_grad_norm=args.max_per_sample_grad_norm,
secure_rng=args.secure_rng,
**clipping,
)
privacy_engine.attach(optimizer)
for epoch in range(args.start_epoch, args.epochs + 1):
train(args, model, train_loader, optimizer, epoch, device)
top1_acc = test(args, model, test_loader, device)
# remember best acc@1 and save checkpoint
is_best = top1_acc > best_acc1
best_acc1 = max(top1_acc, best_acc1)
save_checkpoint(
{
"epoch": epoch + 1,
"arch": "ResNet18",
"state_dict": model.state_dict(),
"best_acc1": best_acc1,
"optimizer": optimizer.state_dict(),
},
is_best,
filename=args.checkpoint_file + ".tar",
)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,491 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
Runs CIFAR10 training with differential privacy.
"""
import argparse
import os
import shutil
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data
import torch.utils.data.distributed
import torch.utils.tensorboard as tensorboard
import torchvision.models as models
import torchvision.transforms as transforms
from opacus import PrivacyEngine
from opacus.utils import stats
from opacus.utils.module_modification import convert_batchnorm_modules
from torchvision.datasets import CIFAR10
from tqdm import tqdm
from torch import vmap
from make_functional import make_functional, load_weights
from functional_utils import grad, grad_with_value
from functools import partial
# from resnet import resnet18
class CudaMemoryLeakCheck():
def __init__(self, name):
self.name = name
# initialize context & RNG to prevent false positive detections
# when the test is the first to initialize those
from torch.testing._internal.common_cuda import initialize_cuda_context_rng
initialize_cuda_context_rng()
@staticmethod
def get_cuda_memory_usage():
# we don't need to synchronize CUDA because the statistics are not tracked at
# the actual freeing, but when the block is marked as free.
num_devices = torch.cuda.device_count()
import gc
gc.collect()
return tuple(torch.cuda.memory_allocated(i) for i in range(num_devices))
def __enter__(self):
self.befores = self.get_cuda_memory_usage()
def __exit__(self, exec_type, exec_value, traceback):
# Don't check for leaks if an exception was thrown
if exec_type is not None:
return
afters = self.get_cuda_memory_usage()
for i, (before, after) in enumerate(zip(self.befores, afters)):
if after - before == 0:
continue
raise RuntimeError(f'{self.name} leaked {after-before} bytes')
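# A hypothetical usage sketch (not necessarily how this script uses it):
#   with CudaMemoryLeakCheck("train step"):
#       train(args, model, train_loader, optimizer, epoch, device)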
def save_checkpoint(state, is_best, filename="checkpoint.tar"):
torch.save(state, filename)
if is_best:
shutil.copyfile(filename, "model_best.pth.tar")
def accuracy(preds, labels):
return (preds == labels).mean()
def compute_norms(sample_grads):
batch_size = sample_grads[0].shape[0]
norms = [sample_grad.view(batch_size, -1).norm(2, dim=-1) for sample_grad in sample_grads]
norms = torch.stack(norms, dim=0).norm(2, dim=0)
return norms
def clip_and_accumulate_and_add_noise(sample_grads, max_per_sample_grad_norm=1.0, noise_multiplier=1.0):
# step 0: compute the norms
sample_norms = compute_norms(sample_grads)
# step 1: compute clipping factors
clip_factor = max_per_sample_grad_norm / (sample_norms + 1e-6)
clip_factor = clip_factor.clamp(max=1.0)
# step 2: clip
grads = tuple(torch.einsum('i,i...', clip_factor, sample_grad)
for sample_grad in sample_grads)
# step 3: add gaussian noise
stddev = max_per_sample_grad_norm * noise_multiplier
noises = tuple(torch.normal(0, stddev, grad_param.shape, device=grad_param.device)
for grad_param in grads)
grads = tuple(noise + grad_param for noise, grad_param in zip(noises, grads))
return grads
def train(args, model, train_loader, optimizer, epoch, device):
use_prototype = False
model.train()
criterion = nn.CrossEntropyLoss()
losses = []
top1_acc = []
for i, (images, target) in enumerate(tqdm(train_loader)):
images = images.to(device)
target = target.to(device)
# Step 1: compute per-sample-grads
weights, func_model, descriptors = make_functional(model)
def compute_loss_and_output(weights, image, target):
images = image.unsqueeze(0)
targets = target.unsqueeze(0)
output = func_model(weights, (images,))
loss = criterion(output, targets)
return loss, output.squeeze(0)
# grad_with_value(f) returns a function that returns (1) the gradient and
# (2) the value of f. `has_aux=True` means that `f` returns a tuple of two values,
# where the first is to be differentiated and the second is not; the auxiliary
# value is then returned as an additional (third) output.
#
# We need to use `grad_with_value(..., has_aux=True)` because we do
# some analyses on the returned loss and output.
grads_loss_output = grad_with_value(compute_loss_and_output, has_aux=True)
sample_grads, sample_loss, output = vmap(partial(grads_loss_output, weights))(images, target)
loss = sample_loss.mean()
# Step 2: Clip the per-sample-grads, sum them to form grads, and add noise
grads = clip_and_accumulate_and_add_noise(
sample_grads, args.max_per_sample_grad_norm, args.sigma)
# `load_weights` is the inverse operation of make_functional. We put
# things back into a model so that we can directly apply optimizers.
# TODO(rzou): this might not be necessary, optimizers just take
# the params straight up.
load_weights(model, descriptors, weights)
for weight_grad, weight in zip(grads, model.parameters()):
weight.grad = weight_grad.detach()
preds = np.argmax(output.detach().cpu().numpy(), axis=1)
labels = target.detach().cpu().numpy()
losses.append(loss.item())
# measure accuracy and record loss
acc1 = accuracy(preds, labels)
top1_acc.append(acc1)
stats.update(stats.StatType.TRAIN, acc1=acc1)
# make sure we take a step after processing the last mini-batch in the
# epoch to ensure we start the next epoch with a clean state
if ((i + 1) % args.n_accumulation_steps == 0) or ((i + 1) == len(train_loader)):
optimizer.step()
optimizer.zero_grad()
else:
optimizer.virtual_step()
if i % args.print_freq == 0:
print(
f"\tTrain Epoch: {epoch} \t"
f"Loss: {np.mean(losses):.6f} "
f"Acc@1: {np.mean(top1_acc):.6f} "
)
def test(args, model, test_loader, device):
model.eval()
criterion = nn.CrossEntropyLoss()
losses = []
top1_acc = []
with torch.no_grad():
for images, target in tqdm(test_loader):
images = images.to(device)
target = target.to(device)
output = model(images)
loss = criterion(output, target)
preds = np.argmax(output.detach().cpu().numpy(), axis=1)
labels = target.detach().cpu().numpy()
acc1 = accuracy(preds, labels)
losses.append(loss.item())
top1_acc.append(acc1)
top1_avg = np.mean(top1_acc)
stats.update(stats.StatType.TEST, acc1=top1_avg)
print(f"\tTest set:" f"Loss: {np.mean(losses):.6f} " f"Acc@1: {top1_avg :.6f} ")
return np.mean(top1_acc)
def main():
parser = argparse.ArgumentParser(description="PyTorch CIFAR10 DP Training")
parser.add_argument(
"-j",
"--workers",
default=2,
type=int,
metavar="N",
help="number of data loading workers (default: 2)",
)
parser.add_argument(
"--epochs",
default=90,
type=int,
metavar="N",
help="number of total epochs to run",
)
parser.add_argument(
"--start-epoch",
default=1,
type=int,
metavar="N",
help="manual epoch number (useful on restarts)",
)
parser.add_argument(
"-b",
"--batch-size",
# This should be 256, but that OOMs using the prototype.
default=64,
type=int,
metavar="N",
help="mini-batch size (default: 64), this is the total "
"batch size of all GPUs on the current node when "
"using Data Parallel or Distributed Data Parallel",
)
parser.add_argument(
"-na",
"--n_accumulation_steps",
default=1,
type=int,
metavar="N",
help="number of mini-batches to accumulate into an effective batch",
)
parser.add_argument(
"--lr",
"--learning-rate",
default=0.001,
type=float,
metavar="LR",
help="initial learning rate",
dest="lr",
)
parser.add_argument(
"--momentum", default=0.9, type=float, metavar="M", help="SGD momentum"
)
parser.add_argument(
"--wd",
"--weight-decay",
default=5e-4,
type=float,
metavar="W",
help="SGD weight decay (default: 1e-4)",
dest="weight_decay",
)
parser.add_argument(
"-p",
"--print-freq",
default=10,
type=int,
metavar="N",
help="print frequency (default: 10)",
)
parser.add_argument(
"--resume",
default="",
type=str,
metavar="PATH",
help="path to latest checkpoint (default: none)",
)
parser.add_argument(
"-e",
"--evaluate",
dest="evaluate",
action="store_true",
help="evaluate model on validation set",
)
parser.add_argument(
"--seed", default=None, type=int, help="seed for initializing training. "
)
parser.add_argument(
"--device",
type=str,
default="cuda",
help="GPU ID for this process (default: 'cuda')",
)
parser.add_argument(
"--sigma",
type=float,
default=1.0,
metavar="S",
help="Noise multiplier (default 1.0)",
)
parser.add_argument(
"-c",
"--max-per-sample-grad_norm",
type=float,
default=1.0,
metavar="C",
help="Clip per-sample gradients to this norm (default 1.0)",
)
parser.add_argument(
"--disable-dp",
action="store_true",
default=False,
help="Disable privacy training and just train with vanilla SGD",
)
parser.add_argument(
"--secure-rng",
action="store_true",
default=False,
help="Enable Secure RNG to have trustworthy privacy guarantees. Comes at a performance cost",
)
parser.add_argument(
"--delta",
type=float,
default=1e-5,
metavar="D",
help="Target delta (default: 1e-5)",
)
parser.add_argument(
"--checkpoint-file",
type=str,
default="checkpoint",
help="path to save check points",
)
parser.add_argument(
"--data-root",
type=str,
default="../cifar10",
help="Where CIFAR10 is/will be stored",
)
parser.add_argument(
"--log-dir", type=str, default="", help="Where Tensorboard log will be stored"
)
parser.add_argument(
"--optim",
type=str,
default="Adam",
help="Optimizer to use (Adam, RMSprop, SGD)",
)
args = parser.parse_args()
args.disable_dp = True
if args.disable_dp and args.n_accumulation_steps > 1:
raise ValueError("Virtual steps only works with enabled DP")
# The following few lines enable stats gathering about the run:
# 1. where the stats should be logged
stats.set_global_summary_writer(
tensorboard.SummaryWriter(os.path.join("/tmp/stat", args.log_dir))
)
# 2. enable stats
stats.add(
# stats about gradient norms aggregated for all layers
stats.Stat(stats.StatType.GRAD, "AllLayers", frequency=0.1),
# stats about gradient norms per layer
stats.Stat(stats.StatType.GRAD, "PerLayer", frequency=0.1),
# stats about clipping
stats.Stat(stats.StatType.GRAD, "ClippingStats", frequency=0.1),
# stats on training accuracy
stats.Stat(stats.StatType.TRAIN, "accuracy", frequency=0.01),
# stats on validation accuracy
stats.Stat(stats.StatType.TEST, "accuracy"),
)
# The following lines enable stat gathering for the clipping process
# and set a default of per layer clipping for the Privacy Engine
clipping = {"clip_per_layer": False, "enable_stat": True}
if args.secure_rng:
assert False
try:
import torchcsprng as prng
except ImportError as e:
msg = (
"To use secure RNG, you must install the torchcsprng package! "
"Check out the instructions here: https://github.com/pytorch/csprng#installation"
)
raise ImportError(msg) from e
generator = prng.create_random_device_generator("/dev/urandom")
else:
generator = None
augmentations = [
transforms.RandomCrop(32, padding=4),
transforms.RandomHorizontalFlip(),
]
normalize = [
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
]
train_transform = transforms.Compose(
augmentations + normalize if args.disable_dp else normalize
)
test_transform = transforms.Compose(normalize)
train_dataset = CIFAR10(
root=args.data_root, train=True, download=True, transform=train_transform
)
train_loader = torch.utils.data.DataLoader(
train_dataset,
batch_size=args.batch_size,
shuffle=True,
num_workers=args.workers,
drop_last=True,
generator=generator,
)
test_dataset = CIFAR10(
root=args.data_root, train=False, download=True, transform=test_transform
)
test_loader = torch.utils.data.DataLoader(
test_dataset,
batch_size=args.batch_size,
shuffle=False,
num_workers=args.workers,
)
best_acc1 = 0
device = torch.device(args.device)
model = convert_batchnorm_modules(models.resnet18(num_classes=10))
# model = CIFAR10Model()
model = model.to(device)
if args.optim == "SGD":
optimizer = optim.SGD(
model.parameters(),
lr=args.lr,
momentum=args.momentum,
weight_decay=args.weight_decay,
)
elif args.optim == "RMSprop":
optimizer = optim.RMSprop(model.parameters(), lr=args.lr)
elif args.optim == "Adam":
optimizer = optim.Adam(model.parameters(), lr=args.lr)
else:
raise NotImplementedError("Optimizer not recognized. Please check spelling")
if not args.disable_dp:
privacy_engine = PrivacyEngine(
model,
batch_size=args.batch_size * args.n_accumulation_steps,
sample_size=len(train_dataset),
alphas=[1 + x / 10.0 for x in range(1, 100)] + list(range(12, 64)),
noise_multiplier=args.sigma,
max_grad_norm=args.max_per_sample_grad_norm,
secure_rng=args.secure_rng,
**clipping,
)
privacy_engine.attach(optimizer)
for epoch in range(args.start_epoch, args.epochs + 1):
train(args, model, train_loader, optimizer, epoch, device)
top1_acc = test(args, model, test_loader, device)
# remember best acc@1 and save checkpoint
is_best = top1_acc > best_acc1
best_acc1 = max(top1_acc, best_acc1)
save_checkpoint(
{
"epoch": epoch + 1,
"arch": "ResNet18",
"state_dict": model.state_dict(),
"best_acc1": best_acc1,
"optimizer": optimizer.state_dict(),
},
is_best,
filename=args.checkpoint_file + ".tar",
)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,457 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
Runs CIFAR10 training with differential privacy.
"""
import argparse
import os
import shutil
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data
import torch.utils.data.distributed
import torch.utils.tensorboard as tensorboard
import torchvision.models as models
import torchvision.transforms as transforms
from opacus import PrivacyEngine
from opacus.utils import stats
from opacus.utils.module_modification import convert_batchnorm_modules
from torchvision.datasets import CIFAR10
from tqdm import tqdm
from torch import vmap
from make_functional import make_functional, load_weights
from functional_utils import grad, grad_with_value
from functools import partial
# from resnet import resnet18
def save_checkpoint(state, is_best, filename="checkpoint.tar"):
torch.save(state, filename)
if is_best:
shutil.copyfile(filename, "model_best.pth.tar")
def accuracy(preds, labels):
return (preds == labels).mean()
def compute_norms(sample_grads):
batch_size = sample_grads[0].shape[0]
norms = [sample_grad.view(batch_size, -1).norm(2, dim=-1) for sample_grad in sample_grads]
norms = torch.stack(norms, dim=0).norm(2, dim=0)
return norms
def clip_and_accumulate_and_add_noise(sample_grads, max_per_sample_grad_norm=1.0, noise_multiplier=1.0):
# step 0: compute the norms
sample_norms = compute_norms(sample_grads)
# step 1: compute clipping factors
clip_factor = max_per_sample_grad_norm / (sample_norms + 1e-6)
clip_factor = clip_factor.clamp(max=1.0)
# step 2: clip
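    # einsum('i,i...') scales each per-sample gradient by its clipping factor and
    # sums over the per-sample (batch) dimension, yielding one aggregated,
    # clipped gradient per parameter.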
grads = tuple(torch.einsum('i,i...', clip_factor, sample_grad)
for sample_grad in sample_grads)
# step 3: add gaussian noise
stddev = max_per_sample_grad_norm * noise_multiplier
noises = tuple(torch.normal(0, stddev, grad_param.shape, device=grad_param.device)
for grad_param in grads)
grads = tuple(noise + grad_param for noise, grad_param in zip(noises, grads))
return grads
def train(args, model, train_loader, optimizer, epoch, device):
use_prototype = False
model.train()
criterion = nn.CrossEntropyLoss()
losses = []
top1_acc = []
for i, (images, target) in enumerate(tqdm(train_loader)):
images = images.to(device)
target = target.to(device)
# Step 1: compute per-sample-grads
weights, func_model, descriptors = make_functional(model)
def compute_loss_and_output(weights, image, target):
images = image.unsqueeze(0)
targets = target.unsqueeze(0)
output = func_model(weights, (images,))
loss = criterion(output, targets)
return loss, output.squeeze(0)
        # grad_with_value(f) returns a function that computes both (1) the gradient
        # and (2) the value of f. `has_aux=True` tells it that f returns a pair whose
        # first element is the quantity to differentiate and whose second element is
        # auxiliary data to be passed through untouched, so the wrapped function
        # returns that auxiliary data as a third output.
        #
        # We need `grad_with_value(..., has_aux=True)` because we do some analysis
        # on both the returned loss and the output.
grads_loss_output = grad_with_value(compute_loss_and_output, has_aux=True)
sample_grads, sample_loss, output = vmap(partial(grads_loss_output, weights))(images, target)
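        # vmap adds a leading per-sample dimension to every output: sample_grads is a
        # tuple of per-parameter tensors of shape (batch_size, *param_shape), while
        # sample_loss and output have shapes (batch_size,) and (batch_size, num_classes).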
loss = sample_loss.mean()
# Step 2: Clip the per-sample-grads, sum them to form grads, and add noise
grads = clip_and_accumulate_and_add_noise(
sample_grads, args.max_per_sample_grad_norm, args.sigma)
# `load_weights` is the inverse operation of make_functional. We put
# things back into a model so that we can directly apply optimizers.
# TODO(rzou): this might not be necessary, optimizers just take
# the params straight up.
load_weights(model, descriptors, weights)
for weight_grad, weight in zip(grads, model.parameters()):
weight.grad = weight_grad.detach()
preds = np.argmax(output.detach().cpu().numpy(), axis=1)
labels = target.detach().cpu().numpy()
losses.append(loss.item())
# measure accuracy and record loss
acc1 = accuracy(preds, labels)
top1_acc.append(acc1)
stats.update(stats.StatType.TRAIN, acc1=acc1)
# make sure we take a step after processing the last mini-batch in the
# epoch to ensure we start the next epoch with a clean state
if ((i + 1) % args.n_accumulation_steps == 0) or ((i + 1) == len(train_loader)):
optimizer.step()
optimizer.zero_grad()
else:
optimizer.virtual_step()
if i % args.print_freq == 0:
print(
f"\tTrain Epoch: {epoch} \t"
f"Loss: {np.mean(losses):.6f} "
f"Acc@1: {np.mean(top1_acc):.6f} "
)
def test(args, model, test_loader, device):
model.eval()
criterion = nn.CrossEntropyLoss()
losses = []
top1_acc = []
with torch.no_grad():
for images, target in tqdm(test_loader):
images = images.to(device)
target = target.to(device)
output = model(images)
loss = criterion(output, target)
preds = np.argmax(output.detach().cpu().numpy(), axis=1)
labels = target.detach().cpu().numpy()
acc1 = accuracy(preds, labels)
losses.append(loss.item())
top1_acc.append(acc1)
top1_avg = np.mean(top1_acc)
stats.update(stats.StatType.TEST, acc1=top1_avg)
print(f"\tTest set:" f"Loss: {np.mean(losses):.6f} " f"Acc@1: {top1_avg :.6f} ")
return np.mean(top1_acc)
def main():
parser = argparse.ArgumentParser(description="PyTorch CIFAR10 DP Training")
parser.add_argument(
"-j",
"--workers",
default=2,
type=int,
metavar="N",
help="number of data loading workers (default: 2)",
)
parser.add_argument(
"--epochs",
default=90,
type=int,
metavar="N",
help="number of total epochs to run",
)
parser.add_argument(
"--start-epoch",
default=1,
type=int,
metavar="N",
help="manual epoch number (useful on restarts)",
)
parser.add_argument(
"-b",
"--batch-size",
# This should be 256, but that OOMs using the prototype.
default=64,
type=int,
metavar="N",
help="mini-batch size (default: 64), this is the total "
"batch size of all GPUs on the current node when "
"using Data Parallel or Distributed Data Parallel",
)
parser.add_argument(
"-na",
"--n_accumulation_steps",
default=1,
type=int,
metavar="N",
help="number of mini-batches to accumulate into an effective batch",
)
parser.add_argument(
"--lr",
"--learning-rate",
default=0.001,
type=float,
metavar="LR",
help="initial learning rate",
dest="lr",
)
parser.add_argument(
"--momentum", default=0.9, type=float, metavar="M", help="SGD momentum"
)
parser.add_argument(
"--wd",
"--weight-decay",
default=5e-4,
type=float,
metavar="W",
help="SGD weight decay (default: 1e-4)",
dest="weight_decay",
)
parser.add_argument(
"-p",
"--print-freq",
default=10,
type=int,
metavar="N",
help="print frequency (default: 10)",
)
parser.add_argument(
"--resume",
default="",
type=str,
metavar="PATH",
help="path to latest checkpoint (default: none)",
)
parser.add_argument(
"-e",
"--evaluate",
dest="evaluate",
action="store_true",
help="evaluate model on validation set",
)
parser.add_argument(
"--seed", default=None, type=int, help="seed for initializing training. "
)
parser.add_argument(
"--device",
type=str,
default="cuda",
help="GPU ID for this process (default: 'cuda')",
)
parser.add_argument(
"--sigma",
type=float,
default=1.0,
metavar="S",
help="Noise multiplier (default 1.0)",
)
parser.add_argument(
"-c",
"--max-per-sample-grad_norm",
type=float,
default=1.0,
metavar="C",
help="Clip per-sample gradients to this norm (default 1.0)",
)
parser.add_argument(
"--disable-dp",
action="store_true",
default=False,
help="Disable privacy training and just train with vanilla SGD",
)
parser.add_argument(
"--secure-rng",
action="store_true",
default=False,
help="Enable Secure RNG to have trustworthy privacy guarantees. Comes at a performance cost",
)
parser.add_argument(
"--delta",
type=float,
default=1e-5,
metavar="D",
help="Target delta (default: 1e-5)",
)
parser.add_argument(
"--checkpoint-file",
type=str,
default="checkpoint",
help="path to save check points",
)
parser.add_argument(
"--data-root",
type=str,
default="../cifar10",
help="Where CIFAR10 is/will be stored",
)
parser.add_argument(
"--log-dir", type=str, default="", help="Where Tensorboard log will be stored"
)
parser.add_argument(
"--optim",
type=str,
default="Adam",
help="Optimizer to use (Adam, RMSprop, SGD)",
)
args = parser.parse_args()
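    # NB: force-disable the Opacus PrivacyEngine below; this example computes
    # per-sample grads and applies clipping/noising manually in train().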
args.disable_dp = True
if args.disable_dp and args.n_accumulation_steps > 1:
raise ValueError("Virtual steps only works with enabled DP")
# The following few lines, enable stats gathering about the run
# 1. where the stats should be logged
stats.set_global_summary_writer(
tensorboard.SummaryWriter(os.path.join("/tmp/stat", args.log_dir))
)
# 2. enable stats
stats.add(
# stats about gradient norms aggregated for all layers
stats.Stat(stats.StatType.GRAD, "AllLayers", frequency=0.1),
# stats about gradient norms per layer
stats.Stat(stats.StatType.GRAD, "PerLayer", frequency=0.1),
# stats about clipping
stats.Stat(stats.StatType.GRAD, "ClippingStats", frequency=0.1),
# stats on training accuracy
stats.Stat(stats.StatType.TRAIN, "accuracy", frequency=0.01),
# stats on validation accuracy
stats.Stat(stats.StatType.TEST, "accuracy"),
)
    # The following lines enable stat gathering for the clipping process
    # and configure flat (all-layers) clipping for the Privacy Engine
clipping = {"clip_per_layer": False, "enable_stat": True}
if args.secure_rng:
assert False
try:
import torchcsprng as prng
except ImportError as e:
msg = (
"To use secure RNG, you must install the torchcsprng package! "
"Check out the instructions here: https://github.com/pytorch/csprng#installation"
)
raise ImportError(msg) from e
generator = prng.create_random_device_generator("/dev/urandom")
else:
generator = None
augmentations = [
transforms.RandomCrop(32, padding=4),
transforms.RandomHorizontalFlip(),
]
normalize = [
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
]
train_transform = transforms.Compose(
augmentations + normalize if args.disable_dp else normalize
)
test_transform = transforms.Compose(normalize)
train_dataset = CIFAR10(
root=args.data_root, train=True, download=True, transform=train_transform
)
train_loader = torch.utils.data.DataLoader(
train_dataset,
batch_size=args.batch_size,
shuffle=True,
num_workers=args.workers,
drop_last=True,
generator=generator,
)
test_dataset = CIFAR10(
root=args.data_root, train=False, download=True, transform=test_transform
)
test_loader = torch.utils.data.DataLoader(
test_dataset,
batch_size=args.batch_size,
shuffle=False,
num_workers=args.workers,
)
best_acc1 = 0
device = torch.device(args.device)
model = convert_batchnorm_modules(models.resnet18(num_classes=10))
# model = CIFAR10Model()
model = model.to(device)
if args.optim == "SGD":
optimizer = optim.SGD(
model.parameters(),
lr=args.lr,
momentum=args.momentum,
weight_decay=args.weight_decay,
)
elif args.optim == "RMSprop":
optimizer = optim.RMSprop(model.parameters(), lr=args.lr)
elif args.optim == "Adam":
optimizer = optim.Adam(model.parameters(), lr=args.lr)
else:
raise NotImplementedError("Optimizer not recognized. Please check spelling")
if not args.disable_dp:
privacy_engine = PrivacyEngine(
model,
batch_size=args.batch_size * args.n_accumulation_steps,
sample_size=len(train_dataset),
alphas=[1 + x / 10.0 for x in range(1, 100)] + list(range(12, 64)),
noise_multiplier=args.sigma,
max_grad_norm=args.max_per_sample_grad_norm,
secure_rng=args.secure_rng,
**clipping,
)
privacy_engine.attach(optimizer)
for epoch in range(args.start_epoch, args.epochs + 1):
train(args, model, train_loader, optimizer, epoch, device)
top1_acc = test(args, model, test_loader, device)
# remember best acc@1 and save checkpoint
is_best = top1_acc > best_acc1
best_acc1 = max(top1_acc, best_acc1)
save_checkpoint(
{
"epoch": epoch + 1,
"arch": "ResNet18",
"state_dict": model.state_dict(),
"best_acc1": best_acc1,
"optimizer": optimizer.state_dict(),
},
is_best,
filename=args.checkpoint_file + ".tar",
)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,456 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
Runs CIFAR10 training with differential privacy.
"""
import argparse
import os
import shutil
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data
import torch.utils.data.distributed
import torch.utils.tensorboard as tensorboard
import torchvision.models as models
import torchvision.transforms as transforms
from opacus import PrivacyEngine
from opacus.utils import stats
from opacus.utils.module_modification import convert_batchnorm_modules
from torchvision.datasets import CIFAR10
from tqdm import tqdm
from torch import vmap
from make_functional import make_functional, load_weights
from functional_utils import grad, grad_with_value
from functools import partial
# from resnet import resnet18
def save_checkpoint(state, is_best, filename="checkpoint.tar"):
torch.save(state, filename)
if is_best:
shutil.copyfile(filename, "model_best.pth.tar")
def accuracy(preds, labels):
return (preds == labels).mean()
def compute_norms(sample_grads):
batch_size = sample_grads[0].shape[0]
norms = [sample_grad.view(batch_size, -1).norm(2, dim=-1) for sample_grad in sample_grads]
norms = torch.stack(norms, dim=0).norm(2, dim=0)
return norms
def clip_and_accumulate_and_add_noise(sample_grads, max_per_sample_grad_norm=1.0, noise_multiplier=1.0):
# step 0: compute the norms
sample_norms = compute_norms(sample_grads)
# step 1: compute clipping factors
clip_factor = max_per_sample_grad_norm / (sample_norms + 1e-6)
clip_factor = clip_factor.clamp(max=1.0)
# step 2: clip
grads = tuple(torch.einsum('i,i...', clip_factor, sample_grad)
for sample_grad in sample_grads)
# step 3: add gaussian noise
stddev = max_per_sample_grad_norm * noise_multiplier
noises = tuple(torch.normal(0, stddev, grad_param.shape, device=grad_param.device)
for grad_param in grads)
grads = tuple(noise + grad_param for noise, grad_param in zip(noises, grads))
return grads
def train(args, model, train_loader, optimizer, epoch, device):
model.train()
criterion = nn.CrossEntropyLoss()
losses = []
top1_acc = []
for i, (images, target) in enumerate(tqdm(train_loader)):
images = images.to(device)
target = target.to(device)
# Step 1: compute per-sample-grads
weights, func_model, descriptors = make_functional(model)
def compute_loss_and_output(weights, image, target):
images = image.unsqueeze(0)
targets = target.unsqueeze(0)
output = func_model(weights, (images,))
loss = criterion(output, targets)
return loss, output.squeeze(0)
        # grad_with_value(f) returns a function that computes both (1) the gradient
        # and (2) the value of f. `has_aux=True` tells it that f returns a pair whose
        # first element is the quantity to differentiate and whose second element is
        # auxiliary data to be passed through untouched, so the wrapped function
        # returns that auxiliary data as a third output.
        #
        # We need `grad_with_value(..., has_aux=True)` because we do some analysis
        # on both the returned loss and the output.
grads_loss_output = grad_with_value(compute_loss_and_output, has_aux=True)
sample_grads, sample_loss, output = vmap(partial(grads_loss_output, weights))(images, target)
loss = sample_loss.mean()
# Step 2: Clip the per-sample-grads, sum them to form grads, and add noise
grads = clip_and_accumulate_and_add_noise(
sample_grads, args.max_per_sample_grad_norm, args.sigma)
# `load_weights` is the inverse operation of make_functional. We put
# things back into a model so that we can directly apply optimizers.
# TODO(rzou): this might not be necessary, optimizers just take
# the params straight up.
load_weights(model, descriptors, weights)
for weight_grad, weight in zip(grads, model.parameters()):
weight.grad = weight_grad.detach()
preds = np.argmax(output.detach().cpu().numpy(), axis=1)
labels = target.detach().cpu().numpy()
losses.append(loss.item())
# measure accuracy and record loss
acc1 = accuracy(preds, labels)
top1_acc.append(acc1)
stats.update(stats.StatType.TRAIN, acc1=acc1)
# make sure we take a step after processing the last mini-batch in the
# epoch to ensure we start the next epoch with a clean state
if ((i + 1) % args.n_accumulation_steps == 0) or ((i + 1) == len(train_loader)):
optimizer.step()
optimizer.zero_grad()
else:
optimizer.virtual_step()
if i % args.print_freq == 0:
print(
f"\tTrain Epoch: {epoch} \t"
f"Loss: {np.mean(losses):.6f} "
f"Acc@1: {np.mean(top1_acc):.6f} "
)
def test(args, model, test_loader, device):
model.eval()
criterion = nn.CrossEntropyLoss()
losses = []
top1_acc = []
with torch.no_grad():
for images, target in tqdm(test_loader):
images = images.to(device)
target = target.to(device)
output = model(images)
loss = criterion(output, target)
preds = np.argmax(output.detach().cpu().numpy(), axis=1)
labels = target.detach().cpu().numpy()
acc1 = accuracy(preds, labels)
losses.append(loss.item())
top1_acc.append(acc1)
top1_avg = np.mean(top1_acc)
stats.update(stats.StatType.TEST, acc1=top1_avg)
print(f"\tTest set:" f"Loss: {np.mean(losses):.6f} " f"Acc@1: {top1_avg :.6f} ")
return np.mean(top1_acc)
def main():
parser = argparse.ArgumentParser(description="PyTorch CIFAR10 DP Training")
parser.add_argument(
"-j",
"--workers",
default=2,
type=int,
metavar="N",
help="number of data loading workers (default: 2)",
)
parser.add_argument(
"--epochs",
default=90,
type=int,
metavar="N",
help="number of total epochs to run",
)
parser.add_argument(
"--start-epoch",
default=1,
type=int,
metavar="N",
help="manual epoch number (useful on restarts)",
)
parser.add_argument(
"-b",
"--batch-size",
# This should be 256, but that OOMs using the prototype.
default=64,
type=int,
metavar="N",
help="mini-batch size (default: 64), this is the total "
"batch size of all GPUs on the current node when "
"using Data Parallel or Distributed Data Parallel",
)
parser.add_argument(
"-na",
"--n_accumulation_steps",
default=1,
type=int,
metavar="N",
help="number of mini-batches to accumulate into an effective batch",
)
parser.add_argument(
"--lr",
"--learning-rate",
default=0.001,
type=float,
metavar="LR",
help="initial learning rate",
dest="lr",
)
parser.add_argument(
"--momentum", default=0.9, type=float, metavar="M", help="SGD momentum"
)
parser.add_argument(
"--wd",
"--weight-decay",
default=5e-4,
type=float,
metavar="W",
help="SGD weight decay (default: 1e-4)",
dest="weight_decay",
)
parser.add_argument(
"-p",
"--print-freq",
default=10,
type=int,
metavar="N",
help="print frequency (default: 10)",
)
parser.add_argument(
"--resume",
default="",
type=str,
metavar="PATH",
help="path to latest checkpoint (default: none)",
)
parser.add_argument(
"-e",
"--evaluate",
dest="evaluate",
action="store_true",
help="evaluate model on validation set",
)
parser.add_argument(
"--seed", default=None, type=int, help="seed for initializing training. "
)
parser.add_argument(
"--device",
type=str,
default="cuda",
help="GPU ID for this process (default: 'cuda')",
)
parser.add_argument(
"--sigma",
type=float,
default=1.0,
metavar="S",
help="Noise multiplier (default 1.0)",
)
parser.add_argument(
"-c",
"--max-per-sample-grad_norm",
type=float,
default=1.0,
metavar="C",
help="Clip per-sample gradients to this norm (default 1.0)",
)
parser.add_argument(
"--disable-dp",
action="store_true",
default=False,
help="Disable privacy training and just train with vanilla SGD",
)
parser.add_argument(
"--secure-rng",
action="store_true",
default=False,
help="Enable Secure RNG to have trustworthy privacy guarantees. Comes at a performance cost",
)
parser.add_argument(
"--delta",
type=float,
default=1e-5,
metavar="D",
help="Target delta (default: 1e-5)",
)
parser.add_argument(
"--checkpoint-file",
type=str,
default="checkpoint",
help="path to save check points",
)
parser.add_argument(
"--data-root",
type=str,
default="../cifar10",
help="Where CIFAR10 is/will be stored",
)
parser.add_argument(
"--log-dir", type=str, default="", help="Where Tensorboard log will be stored"
)
parser.add_argument(
"--optim",
type=str,
default="Adam",
help="Optimizer to use (Adam, RMSprop, SGD)",
)
args = parser.parse_args()
args.disable_dp = True
if args.disable_dp and args.n_accumulation_steps > 1:
raise ValueError("Virtual steps only works with enabled DP")
# The following few lines, enable stats gathering about the run
# 1. where the stats should be logged
stats.set_global_summary_writer(
tensorboard.SummaryWriter(os.path.join("/tmp/stat", args.log_dir))
)
# 2. enable stats
stats.add(
# stats about gradient norms aggregated for all layers
stats.Stat(stats.StatType.GRAD, "AllLayers", frequency=0.1),
# stats about gradient norms per layer
stats.Stat(stats.StatType.GRAD, "PerLayer", frequency=0.1),
# stats about clipping
stats.Stat(stats.StatType.GRAD, "ClippingStats", frequency=0.1),
# stats on training accuracy
stats.Stat(stats.StatType.TRAIN, "accuracy", frequency=0.01),
# stats on validation accuracy
stats.Stat(stats.StatType.TEST, "accuracy"),
)
    # The following lines enable stat gathering for the clipping process
    # and configure flat (all-layers) clipping for the Privacy Engine
clipping = {"clip_per_layer": False, "enable_stat": True}
if args.secure_rng:
assert False
try:
import torchcsprng as prng
except ImportError as e:
msg = (
"To use secure RNG, you must install the torchcsprng package! "
"Check out the instructions here: https://github.com/pytorch/csprng#installation"
)
raise ImportError(msg) from e
generator = prng.create_random_device_generator("/dev/urandom")
else:
generator = None
augmentations = [
transforms.RandomCrop(32, padding=4),
transforms.RandomHorizontalFlip(),
]
normalize = [
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
]
train_transform = transforms.Compose(
augmentations + normalize if args.disable_dp else normalize
)
test_transform = transforms.Compose(normalize)
train_dataset = CIFAR10(
root=args.data_root, train=True, download=True, transform=train_transform
)
train_loader = torch.utils.data.DataLoader(
train_dataset,
batch_size=args.batch_size,
shuffle=True,
num_workers=args.workers,
drop_last=True,
generator=generator,
)
test_dataset = CIFAR10(
root=args.data_root, train=False, download=True, transform=test_transform
)
test_loader = torch.utils.data.DataLoader(
test_dataset,
batch_size=args.batch_size,
shuffle=False,
num_workers=args.workers,
)
best_acc1 = 0
device = torch.device(args.device)
model = convert_batchnorm_modules(models.resnet18(num_classes=10))
# model = CIFAR10Model()
model = model.to(device)
if args.optim == "SGD":
optimizer = optim.SGD(
model.parameters(),
lr=args.lr,
momentum=args.momentum,
weight_decay=args.weight_decay,
)
elif args.optim == "RMSprop":
optimizer = optim.RMSprop(model.parameters(), lr=args.lr)
elif args.optim == "Adam":
optimizer = optim.Adam(model.parameters(), lr=args.lr)
else:
raise NotImplementedError("Optimizer not recognized. Please check spelling")
if not args.disable_dp:
privacy_engine = PrivacyEngine(
model,
batch_size=args.batch_size * args.n_accumulation_steps,
sample_size=len(train_dataset),
alphas=[1 + x / 10.0 for x in range(1, 100)] + list(range(12, 64)),
noise_multiplier=args.sigma,
max_grad_norm=args.max_per_sample_grad_norm,
secure_rng=args.secure_rng,
**clipping,
)
privacy_engine.attach(optimizer)
for epoch in range(args.start_epoch, args.epochs + 1):
train(args, model, train_loader, optimizer, epoch, device)
top1_acc = test(args, model, test_loader, device)
# remember best acc@1 and save checkpoint
is_best = top1_acc > best_acc1
best_acc1 = max(top1_acc, best_acc1)
save_checkpoint(
{
"epoch": epoch + 1,
"arch": "ResNet18",
"state_dict": model.state_dict(),
"best_acc1": best_acc1,
"optimizer": optimizer.state_dict(),
},
is_best,
filename=args.checkpoint_file + ".tar",
)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,71 @@
import torch
import torch.nn as nn
from torch import Tensor
from typing import List, Tuple
import copy
# Utilities to make nn.Module "functional"
# In particular the goal is to be able to provide a function that takes as input
# the parameters and evaluate the nn.Module using fixed inputs.
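# A minimal usage sketch (assuming a toy nn.Linear model; for illustration only):
#   model = nn.Linear(3, 3)
#   weights, func, descriptors = make_functional(model)
#   x = torch.randn(2, 3)
#   out = func(weights, (x,))  # equivalent to model(x), but the params are explicit inputs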
def _del_nested_attr(obj: nn.Module, names: List[str]) -> None:
"""
Deletes the attribute specified by the given list of names.
For example, to delete the attribute obj.conv.weight,
use _del_nested_attr(obj, ['conv', 'weight'])
"""
if len(names) == 1:
delattr(obj, names[0])
else:
_del_nested_attr(getattr(obj, names[0]), names[1:])
def _set_nested_attr(obj: nn.Module, names: List[str], value: Tensor) -> None:
"""
Set the attribute specified by the given list of names to value.
For example, to set the attribute obj.conv.weight,
    use _set_nested_attr(obj, ['conv', 'weight'], value)
"""
if len(names) == 1:
setattr(obj, names[0], value)
else:
_set_nested_attr(getattr(obj, names[0]), names[1:], value)
def extract_weights(mod: nn.Module) -> Tuple[Tuple[Tensor, ...], List[str]]:
"""
This function removes all the Parameters from the model and
return them as a tuple as well as their original attribute names.
The weights must be re-loaded with `load_weights` before the model
can be used again.
Note that this function modifies the model in place and after this
call, mod.parameters() will be empty.
"""
orig_params = tuple(mod.parameters())
# Remove all the parameters in the model
names = []
for name, p in list(mod.named_parameters()):
_del_nested_attr(mod, name.split("."))
names.append(name)
# Make params regular Tensors instead of nn.Parameter
params = tuple(p for p in orig_params)
return params, names
def load_weights(mod: nn.Module, names: List[str], params: Tuple[Tensor, ...], as_params=False) -> None:
"""
Reload a set of weights so that `mod` can be used again to perform a forward pass.
Note that the `params` are regular Tensors (that can have history) and so are left
as Tensors. This means that mod.parameters() will still be empty after this call.
"""
for name, p in zip(names, params):
if as_params:
p = nn.Parameter(p)
_set_nested_attr(mod, name.split("."), p)
def make_functional(model: nn.Module):
weights, descriptors = extract_weights(model)
def fun(weights, data):
mutable_model = copy.deepcopy(model)
load_weights(mutable_model, descriptors, weights)
return mutable_model(*data)
return weights, fun, descriptors

View File

@@ -0,0 +1,71 @@
import torch
import torch.nn as nn
from torch import Tensor
from typing import List, Tuple
import copy
# Utilities to make nn.Module "functional"
# In particular the goal is to be able to provide a function that takes as input
# the parameters and evaluate the nn.Module using fixed inputs.
def _del_nested_attr(obj: nn.Module, names: List[str]) -> None:
"""
Deletes the attribute specified by the given list of names.
For example, to delete the attribute obj.conv.weight,
use _del_nested_attr(obj, ['conv', 'weight'])
"""
if len(names) == 1:
delattr(obj, names[0])
else:
_del_nested_attr(getattr(obj, names[0]), names[1:])
def _set_nested_attr(obj: nn.Module, names: List[str], value: Tensor) -> None:
"""
Set the attribute specified by the given list of names to value.
For example, to set the attribute obj.conv.weight,
    use _set_nested_attr(obj, ['conv', 'weight'], value)
"""
if len(names) == 1:
setattr(obj, names[0], value)
else:
_set_nested_attr(getattr(obj, names[0]), names[1:], value)
def extract_weights(mod: nn.Module) -> Tuple[Tuple[Tensor, ...], List[str]]:
"""
This function removes all the Parameters from the model and
    returns them as a tuple, along with their original attribute names.
The weights must be re-loaded with `load_weights` before the model
can be used again.
Note that this function modifies the model in place and after this
call, mod.parameters() will be empty.
"""
orig_params = tuple(mod.parameters())
# Remove all the parameters in the model
names = []
for name, p in list(mod.named_parameters()):
_del_nested_attr(mod, name.split("."))
names.append(name)
# Make params regular Tensors instead of nn.Parameter
params = tuple(p.detach().requires_grad_() for p in orig_params)
return params, names
def load_weights(mod: nn.Module, names: List[str], params: Tuple[Tensor, ...], as_params=False) -> None:
"""
Reload a set of weights so that `mod` can be used again to perform a forward pass.
Note that the `params` are regular Tensors (that can have history) and so are left
as Tensors. This means that mod.parameters() will still be empty after this call.
"""
for name, p in zip(names, params):
if as_params:
p = nn.Parameter(p)
_set_nested_attr(mod, name.split("."), p)
def make_functional(model: nn.Module):
weights, descriptors = extract_weights(model)
def fun(weights, data):
mutable_model = copy.deepcopy(model)
load_weights(mutable_model, descriptors, weights)
return mutable_model(*data)
return weights, fun, descriptors

View File

@@ -0,0 +1,122 @@
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from functorch import make_functional, grad_with_value, vmap
# Adapted from http://willwhitney.com/parallel-training-jax.html
# GOAL: Demonstrate that it is possible to use eager-mode vmap
# to parallelize training over models.
# NB: this code runs off of a branch on zou3519/pytorch:dynlayer
DEVICE = 'cpu'
# Step 1: Make some spirals
def make_spirals(n_samples, noise_std=0., rotations=1.):
ts = torch.linspace(0, 1, n_samples, device=DEVICE)
rs = ts ** 0.5
thetas = rs * rotations * 2 * math.pi
signs = torch.randint(0, 2, (n_samples,), device=DEVICE) * 2 - 1
labels = (signs > 0).to(torch.long)
xs = rs * signs * torch.cos(thetas) + torch.randn(n_samples, device=DEVICE) * noise_std
ys = rs * signs * torch.sin(thetas) + torch.randn(n_samples, device=DEVICE) * noise_std
points = torch.stack([xs, ys], dim=1)
return points, labels
points, labels = make_spirals(100, noise_std=0.05)
# Step 2: Define two-layer MLP and loss function
class MLPClassifier(nn.Module):
def __init__(self, hidden_dim=32, n_classes=2):
super().__init__()
self.hidden_dim = hidden_dim
self.n_classes = n_classes
self.fc1 = nn.Linear(2, self.hidden_dim)
self.fc2 = nn.Linear(self.hidden_dim, self.n_classes)
def forward(self, x):
x = self.fc1(x)
x = F.relu(x)
x = self.fc2(x)
x = F.log_softmax(x, -1)
return x
loss_fn = nn.NLLLoss()
# Step 3: Make the model functional(!!) and define a training function.
# NB: this mechanism doesn't exist in PyTorch today, but we want it to:
# https://github.com/pytorch/pytorch/issues/49171
weights, func_model, _ = make_functional(MLPClassifier().to(DEVICE))
def train_step_fn(weights, batch, targets, lr=0.2):
def compute_loss(weights, batch, targets):
output = func_model(weights, (batch,))
loss = loss_fn(output, targets)
return loss
grad_weights, loss = grad_with_value(compute_loss)(weights, batch, targets)
# NB: PyTorch is missing a "functional optimizer API" (possibly coming soon)
# so we are going to re-implement SGD here.
new_weights = []
with torch.no_grad():
for grad_weight, weight in zip(grad_weights, weights):
new_weights.append(weight - grad_weight * lr)
# NB: return looks weird because torch.vmap must return Tensors
return (loss, *new_weights)
def unpack(train_result):
return train_result[0], train_result[1:]
# Step 4: Let's verify this actually trains.
# We should see the loss decrease.
def step4():
global weights
for i in range(2000):
loss, weights = unpack(train_step_fn(weights, points, labels))
if i % 100 == 0:
print(loss)
step4()
# Step 5: We're ready for multiple models. Let's define an init_fn
# that, given a number of models, returns to us all of the weights.
def init_fn(num_models):
models = tuple(MLPClassifier() for _ in range(num_models))
weights = tuple(make_functional(model)[0] for model in models)
weights = tuple(zip(*weights))
weights = tuple(torch.stack(shards).detach() for shards in weights)
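    # each element of `weights` is now a stacked tensor of shape (num_models, *param_shape)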
return weights
# Step 6: Now, can we try multiple models at the same time?
# The answer is: yes! `loss` now holds one value per model, and we can see
# that both values keep decreasing.
def step6():
parallel_train_step_fn = vmap(train_step_fn, in_dims=(0, None, None))
batched_weights = init_fn(num_models=2)
for i in range(2000):
loss, batched_weights = unpack(parallel_train_step_fn(batched_weights, points, labels))
if i % 200 == 0:
print(loss)
step6()
# Step 7: Now, the flaw with step 6 is that we were training on the same exact
# data. This can lead to all of the models in the ensemble overfitting in the
# same way. The solution that http://willwhitney.com/parallel-training-jax.html
# applies is to randomly subset the data so that the models do not receive
# exactly the same data in each training step (see the sketch below).
# Because the goal of this doc is to show that we can use eager-mode vmap to
# achieve similar things as JAX, the rest of this is left as an exercise to the reader.
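# A minimal sketch of that exercise (illustrative only, not called below): sample a
# different random minibatch per model and vmap over the data as well as the weights.
# The minibatch size and index sampling here are assumptions, not part of the original doc.
def step7_sketch(num_models=2, minibatch_size=32, steps=2000):
    parallel_train_step_fn = vmap(train_step_fn, in_dims=(0, 0, 0))
    batched_weights = init_fn(num_models)
    for i in range(steps):
        # shape (num_models, minibatch_size): a different random subset per model
        idx = torch.randint(0, points.shape[0], (num_models, minibatch_size), device=DEVICE)
        loss, batched_weights = unpack(
            parallel_train_step_fn(batched_weights, points[idx], labels[idx]))
        if i % 200 == 0:
            print(loss)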
# In conclusion, to achieve what http://willwhitney.com/parallel-training-jax.html
# does, we used the following additional items that PyTorch does not have:
# 1. NN module functional API that turns a module into a (state, state_less_fn) pair
# 2. Functional optimizers
# 3. A "functional" grad API (that effectively wraps autograd.grad)
# 4. Composability between the functional grad API and torch.vmap.

View File

@@ -0,0 +1,2 @@
catch throw
r maml-omniglot-transforms.py

View File

@@ -0,0 +1,3 @@
omniglot/
maml-accs.png

View File

@@ -0,0 +1,17 @@
# Omniglot MAML examples
In this directory we've provided some examples of training Omniglot that reproduce the experiments from [the original MAML paper](https://arxiv.org/abs/1703.03400).
They can be run via `python {filename}`.
`maml-omniglot-higher.py` uses the [facebookresearch/higher](https://github.com/facebookresearch/higher) metalearning package and is the reference implementation. It runs all of its tasks sequentially.
`maml-omniglot-transforms.py` uses an experimental vmap (and functional grad) prototype. It runs all of its tasks in parallel. In theory this should lead to some speedups, but we haven't finished implementing all the rules for vmap that would actually make training faster.
`maml-omniglot-ptonly.py` is an implementation of `maml-omniglot-transforms.py` that runs all of its tasks sequentially (and also doesn't use the higher package).
The prototype vmap used for these experiments currently runs off of a branch.
We'd love some feedback on the prototype and encourage folks to try it out.
It's a bit difficult to install, but here are some options:
1. If you're on the FAIR cluster, we can share a path to a conda environment
2. We are looking into building binaries using our branch and shipping them.

View File

@@ -0,0 +1,277 @@
#!/usr/bin/env python3
#
# Copyright (c) Facebook, Inc. and its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This example shows how to use higher to do Model Agnostic Meta Learning (MAML)
for few-shot Omniglot classification.
For more details see the original MAML paper:
https://arxiv.org/abs/1703.03400
This code has been modified from Jackie Loong's PyTorch MAML implementation:
https://github.com/dragen1860/MAML-Pytorch/blob/master/omniglot_train.py
Our MAML++ fork and experiments are available at:
https://github.com/bamos/HowToTrainYourMAMLPytorch
"""
import argparse
import time
import typing
import pandas as pd
import numpy as np
import matplotlib as mpl
mpl.use('Agg')
import matplotlib.pyplot as plt
plt.style.use('bmh')
import torch
from torch import nn
import torch.nn.functional as F
import torch.optim as optim
import higher
from support.omniglot_loaders import OmniglotNShot
def main():
argparser = argparse.ArgumentParser()
argparser.add_argument('--n_way', type=int, help='n way', default=5)
argparser.add_argument(
'--k_spt', type=int, help='k shot for support set', default=5)
argparser.add_argument(
'--k_qry', type=int, help='k shot for query set', default=15)
argparser.add_argument(
'--device', type=str, help='device', default='cuda')
argparser.add_argument(
'--task_num',
type=int,
help='meta batch size, namely task num',
default=32)
argparser.add_argument('--seed', type=int, help='random seed', default=1)
args = argparser.parse_args()
torch.manual_seed(args.seed)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(args.seed)
np.random.seed(args.seed)
# Set up the Omniglot loader.
device = args.device
db = OmniglotNShot(
'/tmp/omniglot-data',
batchsz=args.task_num,
n_way=args.n_way,
k_shot=args.k_spt,
k_query=args.k_qry,
imgsz=28,
device=device,
)
# Create a vanilla PyTorch neural network that will be
# automatically monkey-patched by higher later.
# Before higher, models could *not* be created like this
# and the parameters needed to be manually updated and copied
# for the updates.
net = nn.Sequential(
nn.Conv2d(1, 64, 3),
nn.BatchNorm2d(64, momentum=1, affine=True),
nn.ReLU(inplace=True),
nn.MaxPool2d(2, 2),
nn.Conv2d(64, 64, 3),
nn.BatchNorm2d(64, momentum=1, affine=True),
nn.ReLU(inplace=True),
nn.MaxPool2d(2, 2),
nn.Conv2d(64, 64, 3),
nn.BatchNorm2d(64, momentum=1, affine=True),
nn.ReLU(inplace=True),
nn.MaxPool2d(2, 2),
Flatten(),
nn.Linear(64, args.n_way)).to(device)
# We will use Adam to (meta-)optimize the initial parameters
# to be adapted.
meta_opt = optim.Adam(net.parameters(), lr=1e-3)
log = []
for epoch in range(100):
train(db, net, device, meta_opt, epoch, log)
test(db, net, device, epoch, log)
plot(log)
def train(db, net, device, meta_opt, epoch, log):
net.train()
n_train_iter = db.x_train.shape[0] // db.batchsz
for batch_idx in range(n_train_iter):
start_time = time.time()
# Sample a batch of support and query images and labels.
x_spt, y_spt, x_qry, y_qry = db.next()
task_num, setsz, c_, h, w = x_spt.size()
querysz = x_qry.size(1)
# TODO: Maybe pull this out into a separate module so it
# doesn't have to be duplicated between `train` and `test`?
# Initialize the inner optimizer to adapt the parameters to
# the support set.
n_inner_iter = 5
inner_opt = torch.optim.SGD(net.parameters(), lr=1e-1)
qry_losses = []
qry_accs = []
meta_opt.zero_grad()
for i in range(task_num):
with higher.innerloop_ctx(
net, inner_opt, copy_initial_weights=False
) as (fnet, diffopt):
# Optimize the likelihood of the support set by taking
# gradient steps w.r.t. the model's parameters.
# This adapts the model's meta-parameters to the task.
# higher is able to automatically keep copies of
# your network's parameters as they are being updated.
for _ in range(n_inner_iter):
spt_logits = fnet(x_spt[i])
spt_loss = F.cross_entropy(spt_logits, y_spt[i])
diffopt.step(spt_loss)
# The final set of adapted parameters will induce some
# final loss and accuracy on the query dataset.
# These will be used to update the model's meta-parameters.
qry_logits = fnet(x_qry[i])
qry_loss = F.cross_entropy(qry_logits, y_qry[i])
qry_losses.append(qry_loss.detach())
qry_acc = (qry_logits.argmax(
dim=1) == y_qry[i]).sum().item() / querysz
qry_accs.append(qry_acc)
# print([b.shape for b in fnet[1].buffers()])
# Update the model's meta-parameters to optimize the query
# losses across all of the tasks sampled in this batch.
# This unrolls through the gradient steps.
qry_loss.backward()
meta_opt.step()
qry_losses = sum(qry_losses) / task_num
qry_accs = 100. * sum(qry_accs) / task_num
i = epoch + float(batch_idx) / n_train_iter
iter_time = time.time() - start_time
if batch_idx % 4 == 0:
print(
f'[Epoch {i:.2f}] Train Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f} | Time: {iter_time:.2f}'
)
log.append({
'epoch': i,
'loss': qry_losses,
'acc': qry_accs,
'mode': 'train',
'time': time.time(),
})
def test(db, net, device, epoch, log):
# Crucially in our testing procedure here, we do *not* fine-tune
# the model during testing for simplicity.
# Most research papers using MAML for this task do an extra
# stage of fine-tuning here that should be added if you are
# adapting this code for research.
net.train()
n_test_iter = db.x_test.shape[0] // db.batchsz
qry_losses = []
qry_accs = []
for batch_idx in range(n_test_iter):
x_spt, y_spt, x_qry, y_qry = db.next('test')
task_num, setsz, c_, h, w = x_spt.size()
querysz = x_qry.size(1)
# TODO: Maybe pull this out into a separate module so it
# doesn't have to be duplicated between `train` and `test`?
n_inner_iter = 5
inner_opt = torch.optim.SGD(net.parameters(), lr=1e-1)
for i in range(task_num):
with higher.innerloop_ctx(net, inner_opt, track_higher_grads=False) as (fnet, diffopt):
# Optimize the likelihood of the support set by taking
# gradient steps w.r.t. the model's parameters.
# This adapts the model's meta-parameters to the task.
for _ in range(n_inner_iter):
spt_logits = fnet(x_spt[i])
spt_loss = F.cross_entropy(spt_logits, y_spt[i])
diffopt.step(spt_loss)
# The query loss and acc induced by these parameters.
qry_logits = fnet(x_qry[i]).detach()
qry_loss = F.cross_entropy(
qry_logits, y_qry[i], reduction='none')
qry_losses.append(qry_loss.detach())
qry_accs.append(
(qry_logits.argmax(dim=1) == y_qry[i]).detach())
qry_losses = torch.cat(qry_losses).mean().item()
qry_accs = 100. * torch.cat(qry_accs).float().mean().item()
print(
f'[Epoch {epoch+1:.2f}] Test Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f}'
)
log.append({
'epoch': epoch + 1,
'loss': qry_losses,
'acc': qry_accs,
'mode': 'test',
'time': time.time(),
})
def plot(log):
# Generally you should pull your plotting code out of your training
# script but we are doing it here for brevity.
df = pd.DataFrame(log)
fig, ax = plt.subplots(figsize=(6, 4))
train_df = df[df['mode'] == 'train']
test_df = df[df['mode'] == 'test']
ax.plot(train_df['epoch'], train_df['acc'], label='Train')
ax.plot(test_df['epoch'], test_df['acc'], label='Test')
ax.set_xlabel('Epoch')
ax.set_ylabel('Accuracy')
ax.set_ylim(70, 100)
fig.legend(ncol=2, loc='lower right')
fig.tight_layout()
fname = 'maml-accs.png'
print(f'--- Plotting accuracy to {fname}')
fig.savefig(fname)
plt.close(fig)
# Won't need this after this PR is merged in:
# https://github.com/pytorch/pytorch/pull/22245
class Flatten(nn.Module):
def forward(self, input):
return input.view(input.size(0), -1)
if __name__ == '__main__':
main()

View File

@@ -0,0 +1,270 @@
#!/usr/bin/env python3
#
# Copyright (c) Facebook, Inc. and its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This example shows how to use higher to do Model Agnostic Meta Learning (MAML)
for few-shot Omniglot classification.
For more details see the original MAML paper:
https://arxiv.org/abs/1703.03400
This code has been modified from Jackie Loong's PyTorch MAML implementation:
https://github.com/dragen1860/MAML-Pytorch/blob/master/omniglot_train.py
Our MAML++ fork and experiments are available at:
https://github.com/bamos/HowToTrainYourMAMLPytorch
"""
import argparse
import time
import typing
import pandas as pd
import numpy as np
import matplotlib as mpl
mpl.use('Agg')
import matplotlib.pyplot as plt
plt.style.use('bmh')
import torch
from torch import nn
import torch.nn.functional as F
import torch.optim as optim
from torch.eager_transforms import make_functional_with_buffers
import higher
from support.omniglot_loaders import OmniglotNShot
def main():
argparser = argparse.ArgumentParser()
argparser.add_argument('--n_way', type=int, help='n way', default=5)
argparser.add_argument(
'--k_spt', type=int, help='k shot for support set', default=5)
argparser.add_argument(
'--k_qry', type=int, help='k shot for query set', default=15)
argparser.add_argument(
'--device', type=str, help='device', default='cuda')
argparser.add_argument(
'--task_num',
type=int,
help='meta batch size, namely task num',
default=32)
argparser.add_argument('--seed', type=int, help='random seed', default=1)
args = argparser.parse_args()
torch.manual_seed(args.seed)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(args.seed)
np.random.seed(args.seed)
# Set up the Omniglot loader.
device = args.device
db = OmniglotNShot(
'/tmp/omniglot-data',
batchsz=args.task_num,
n_way=args.n_way,
k_shot=args.k_spt,
k_query=args.k_qry,
imgsz=28,
device=device,
)
# Create a vanilla PyTorch neural network that will be
# automatically monkey-patched by higher later.
# Before higher, models could *not* be created like this
# and the parameters needed to be manually updated and copied
# for the updates.
net = nn.Sequential(
nn.Conv2d(1, 64, 3),
nn.BatchNorm2d(64, momentum=1, affine=True),
nn.ReLU(inplace=True),
nn.MaxPool2d(2, 2),
nn.Conv2d(64, 64, 3),
nn.BatchNorm2d(64, momentum=1, affine=True),
nn.ReLU(inplace=True),
nn.MaxPool2d(2, 2),
nn.Conv2d(64, 64, 3),
nn.BatchNorm2d(64, momentum=1, affine=True),
nn.ReLU(inplace=True),
nn.MaxPool2d(2, 2),
Flatten(),
nn.Linear(64, args.n_way)).to(device)
net.train()
    params, buffers, fnet, _, _ = make_functional_with_buffers(net)
# We will use Adam to (meta-)optimize the initial parameters
# to be adapted.
meta_opt = optim.Adam(params, lr=1e-3)
log = []
for epoch in range(100):
train(db, [params, buffers, fnet], device, meta_opt, epoch, log)
test(db, [params, buffers, fnet], device, epoch, log)
plot(log)
def train(db, net, device, meta_opt, epoch, log):
params, buffers, fnet = net
n_train_iter = db.x_train.shape[0] // db.batchsz
for batch_idx in range(n_train_iter):
start_time = time.time()
# Sample a batch of support and query images and labels.
x_spt, y_spt, x_qry, y_qry = db.next()
task_num, setsz, c_, h, w = x_spt.size()
querysz = x_qry.size(1)
# TODO: Maybe pull this out into a separate module so it
# doesn't have to be duplicated between `train` and `test`?
# Initialize the inner optimizer to adapt the parameters to
# the support set.
n_inner_iter = 5
# inner_opt = torch.optim.SGD(net.parameters(), lr=1e-1)
qry_losses = []
qry_accs = []
meta_opt.zero_grad()
for i in range(task_num):
# Optimize the likelihood of the support set by taking
# gradient steps w.r.t. the model's parameters.
# This adapts the model's meta-parameters to the task.
new_params = params
for _ in range(n_inner_iter):
spt_logits = fnet(new_params, buffers, (x_spt[i],))
spt_loss = F.cross_entropy(spt_logits, y_spt[i])
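                # create_graph=True keeps the inner-loop updates differentiable so the
                # meta-update below can backprop through them.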
grads = torch.autograd.grad(spt_loss, new_params, create_graph=True)
                new_params = [p - g * 1e-1 for p, g in zip(new_params, grads)]
# The final set of adapted parameters will induce some
# final loss and accuracy on the query dataset.
# These will be used to update the model's meta-parameters.
qry_logits = fnet(new_params, buffers, (x_qry[i],))
qry_loss = F.cross_entropy(qry_logits, y_qry[i])
qry_losses.append(qry_loss.detach())
qry_acc = (qry_logits.argmax(
dim=1) == y_qry[i]).sum().item() / querysz
qry_accs.append(qry_acc)
# Update the model's meta-parameters to optimize the query
# losses across all of the tasks sampled in this batch.
# This unrolls through the gradient steps.
qry_loss.backward()
meta_opt.step()
qry_losses = sum(qry_losses) / task_num
qry_accs = 100. * sum(qry_accs) / task_num
i = epoch + float(batch_idx) / n_train_iter
iter_time = time.time() - start_time
if batch_idx % 4 == 0:
print(
f'[Epoch {i:.2f}] Train Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f} | Time: {iter_time:.2f}'
)
log.append({
'epoch': i,
'loss': qry_losses,
'acc': qry_accs,
'mode': 'train',
'time': time.time(),
})
def test(db, net, device, epoch, log):
# Crucially in our testing procedure here, we do *not* fine-tune
# the model during testing for simplicity.
# Most research papers using MAML for this task do an extra
# stage of fine-tuning here that should be added if you are
# adapting this code for research.
[params, buffers, fnet] = net
n_test_iter = db.x_test.shape[0] // db.batchsz
qry_losses = []
qry_accs = []
for batch_idx in range(n_test_iter):
x_spt, y_spt, x_qry, y_qry = db.next('test')
task_num, setsz, c_, h, w = x_spt.size()
# TODO: Maybe pull this out into a separate module so it
# doesn't have to be duplicated between `train` and `test`?
n_inner_iter = 5
for i in range(task_num):
new_params = params
for _ in range(n_inner_iter):
spt_logits = fnet(new_params, buffers, (x_spt[i],))
spt_loss = F.cross_entropy(spt_logits, y_spt[i])
grads = torch.autograd.grad(spt_loss, new_params)
                new_params = [p - g * 1e-1 for p, g in zip(new_params, grads)]
# The query loss and acc induced by these parameters.
qry_logits = fnet(new_params, buffers, (x_qry[i],)).detach()
qry_loss = F.cross_entropy(
qry_logits, y_qry[i], reduction='none')
qry_losses.append(qry_loss.detach())
qry_accs.append(
(qry_logits.argmax(dim=1) == y_qry[i]).detach())
qry_losses = torch.cat(qry_losses).mean().item()
qry_accs = 100. * torch.cat(qry_accs).float().mean().item()
print(
f'[Epoch {epoch+1:.2f}] Test Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f}'
)
log.append({
'epoch': epoch + 1,
'loss': qry_losses,
'acc': qry_accs,
'mode': 'test',
'time': time.time(),
})
def plot(log):
# Generally you should pull your plotting code out of your training
# script but we are doing it here for brevity.
df = pd.DataFrame(log)
fig, ax = plt.subplots(figsize=(6, 4))
train_df = df[df['mode'] == 'train']
test_df = df[df['mode'] == 'test']
ax.plot(train_df['epoch'], train_df['acc'], label='Train')
ax.plot(test_df['epoch'], test_df['acc'], label='Test')
ax.set_xlabel('Epoch')
ax.set_ylabel('Accuracy')
ax.set_ylim(70, 100)
fig.legend(ncol=2, loc='lower right')
fig.tight_layout()
fname = 'maml-accs.png'
print(f'--- Plotting accuracy to {fname}')
fig.savefig(fname)
plt.close(fig)
# Won't need this after this PR is merged in:
# https://github.com/pytorch/pytorch/pull/22245
class Flatten(nn.Module):
def forward(self, input):
return input.view(input.size(0), -1)
if __name__ == '__main__':
main()

View File

@@ -0,0 +1,276 @@
#!/usr/bin/env python3
#
# Copyright (c) Facebook, Inc. and its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This example shows how to use higher to do Model Agnostic Meta Learning (MAML)
for few-shot Omniglot classification.
For more details see the original MAML paper:
https://arxiv.org/abs/1703.03400
This code has been modified from Jackie Loong's PyTorch MAML implementation:
https://github.com/dragen1860/MAML-Pytorch/blob/master/omniglot_train.py
Our MAML++ fork and experiments are available at:
https://github.com/bamos/HowToTrainYourMAMLPytorch
"""
import argparse
import time
import typing
import functools
import pandas as pd
import numpy as np
import matplotlib as mpl
mpl.use('Agg')
import matplotlib.pyplot as plt
plt.style.use('bmh')
import torch
from torch import nn
import torch.nn.functional as F
import torch.optim as optim
from functorch import make_functional_with_buffers, vmap, grad
import higher
from support.omniglot_loaders import OmniglotNShot
# torch._C._debug_only_display_vmap_fallback_warnings(True)
def main():
argparser = argparse.ArgumentParser()
argparser.add_argument('--n_way', type=int, help='n way', default=5)
argparser.add_argument(
'--k_spt', type=int, help='k shot for support set', default=5)
argparser.add_argument(
'--k_qry', type=int, help='k shot for query set', default=15)
argparser.add_argument(
'--device', type=str, help='device', default='cuda')
argparser.add_argument(
'--task_num',
type=int,
help='meta batch size, namely task num',
default=32)
argparser.add_argument('--seed', type=int, help='random seed', default=1)
args = argparser.parse_args()
torch.manual_seed(args.seed)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(args.seed)
np.random.seed(args.seed)
# Set up the Omniglot loader.
device = args.device
db = OmniglotNShot(
'/tmp/omniglot-data',
batchsz=args.task_num,
n_way=args.n_way,
k_shot=args.k_spt,
k_query=args.k_qry,
imgsz=28,
device=device,
)
# Create a vanilla PyTorch neural network.
    # TODO: The prototype doesn't support in-place relu (and some other
    # in-place operations); that can be fixed.
inplace_relu = False
net = nn.Sequential(
nn.Conv2d(1, 64, 3),
nn.BatchNorm2d(64, momentum=1, affine=True),
nn.ReLU(inplace=inplace_relu),
nn.MaxPool2d(2, 2),
nn.Conv2d(64, 64, 3),
nn.BatchNorm2d(64, momentum=1, affine=True),
nn.ReLU(inplace=inplace_relu),
nn.MaxPool2d(2, 2),
nn.Conv2d(64, 64, 3),
nn.BatchNorm2d(64, momentum=1, affine=True),
nn.ReLU(inplace=inplace_relu),
nn.MaxPool2d(2, 2),
Flatten(),
nn.Linear(64, args.n_way)).to(device)
net.train()
# Given this module we've created, rip out the parameters and buffers
# and return a functional version of the module. `fnet` is stateless
# and can be called with `fnet(params, buffers, args, kwargs)`
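    # e.g. (illustrative; shapes assumed): fnet(params, buffers, (torch.randn(8, 1, 28, 28, device=device),))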
    params, buffers, fnet, _, _ = make_functional_with_buffers(net)
# We will use Adam to (meta-)optimize the initial parameters
# to be adapted.
meta_opt = optim.Adam(params, lr=1e-3)
log = []
for epoch in range(100):
train(db, [params, buffers, fnet], device, meta_opt, epoch, log)
test(db, [params, buffers, fnet], device, epoch, log)
plot(log)
# Trains a model for n_inner_iter using the support and returns a loss
# using the query.
def loss_for_task(net, n_inner_iter, x_spt, y_spt, x_qry, y_qry):
params, buffers, fnet = net
querysz = x_qry.size(0)
def compute_loss(new_params, buffers, x, y):
logits = fnet(new_params, buffers, (x,))
loss = F.cross_entropy(logits, y)
return loss
new_params = params
for _ in range(n_inner_iter):
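        # functorch's grad differentiates w.r.t. the first argument (new_params) by
        # default; buffers, x_spt, and y_spt are treated as constants.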
grads = grad(compute_loss)(new_params, buffers, x_spt, y_spt)
        new_params = [p - g * 1e-1 for p, g in zip(new_params, grads)]
# The final set of adapted parameters will induce some
# final loss and accuracy on the query dataset.
# These will be used to update the model's meta-parameters.
qry_logits = fnet(new_params, buffers, (x_qry,))
qry_loss = F.cross_entropy(qry_logits, y_qry)
qry_acc = (qry_logits.argmax(
dim=1) == y_qry).sum() / querysz
return qry_loss, qry_acc
def train(db, net, device, meta_opt, epoch, log):
params, buffers, fnet = net
n_train_iter = db.x_train.shape[0] // db.batchsz
for batch_idx in range(n_train_iter):
start_time = time.time()
# Sample a batch of support and query images and labels.
x_spt, y_spt, x_qry, y_qry = db.next()
task_num, setsz, c_, h, w = x_spt.size()
n_inner_iter = 5
meta_opt.zero_grad()
# In parallel, trains one model per task. There is a support (x, y)
# for each task and a query (x, y) for each task.
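        # vmap maps over the leading task dimension: x_spt is (task_num, setsz, C, H, W),
        # so each call to compute_loss_for_task sees a single task's support/query batch.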
compute_loss_for_task = functools.partial(loss_for_task, net, n_inner_iter)
qry_losses, qry_accs = vmap(compute_loss_for_task)(x_spt, y_spt, x_qry, y_qry)
# Compute the maml loss by summing together the returned losses.
qry_losses.sum().backward()
meta_opt.step()
qry_losses = qry_losses.detach().sum() / task_num
qry_accs = 100. * qry_accs.sum() / task_num
i = epoch + float(batch_idx) / n_train_iter
iter_time = time.time() - start_time
if batch_idx % 4 == 0:
print(
f'[Epoch {i:.2f}] Train Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f} | Time: {iter_time:.2f}'
)
log.append({
'epoch': i,
'loss': qry_losses,
'acc': qry_accs,
'mode': 'train',
'time': time.time(),
})
def test(db, net, device, epoch, log):
# Crucially in our testing procedure here, we do *not* fine-tune
# the model during testing for simplicity.
# Most research papers using MAML for this task do an extra
# stage of fine-tuning here that should be added if you are
# adapting this code for research.
[params, buffers, fnet] = net
n_test_iter = db.x_test.shape[0] // db.batchsz
qry_losses = []
qry_accs = []
for batch_idx in range(n_test_iter):
x_spt, y_spt, x_qry, y_qry = db.next('test')
task_num, setsz, c_, h, w = x_spt.size()
# TODO: Maybe pull this out into a separate module so it
# doesn't have to be duplicated between `train` and `test`?
n_inner_iter = 5
for i in range(task_num):
new_params = params
for _ in range(n_inner_iter):
spt_logits = fnet(new_params, buffers, (x_spt[i],))
spt_loss = F.cross_entropy(spt_logits, y_spt[i])
grads = torch.autograd.grad(spt_loss, new_params)
                new_params = [p - g * 1e-1 for p, g in zip(new_params, grads)]
# The query loss and acc induced by these parameters.
qry_logits = fnet(new_params, buffers, (x_qry[i],)).detach()
qry_loss = F.cross_entropy(
qry_logits, y_qry[i], reduction='none')
qry_losses.append(qry_loss.detach())
qry_accs.append(
(qry_logits.argmax(dim=1) == y_qry[i]).detach())
qry_losses = torch.cat(qry_losses).mean().item()
qry_accs = 100. * torch.cat(qry_accs).float().mean().item()
print(
f'[Epoch {epoch+1:.2f}] Test Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f}'
)
log.append({
'epoch': epoch + 1,
'loss': qry_losses,
'acc': qry_accs,
'mode': 'test',
'time': time.time(),
})
def plot(log):
# Generally you should pull your plotting code out of your training
# script but we are doing it here for brevity.
df = pd.DataFrame(log)
fig, ax = plt.subplots(figsize=(6, 4))
train_df = df[df['mode'] == 'train']
test_df = df[df['mode'] == 'test']
ax.plot(train_df['epoch'], train_df['acc'], label='Train')
ax.plot(test_df['epoch'], test_df['acc'], label='Test')
ax.set_xlabel('Epoch')
ax.set_ylabel('Accuracy')
ax.set_ylim(70, 100)
fig.legend(ncol=2, loc='lower right')
fig.tight_layout()
fname = 'maml-accs.png'
print(f'--- Plotting accuracy to {fname}')
fig.savefig(fname)
plt.close(fig)
# Won't need this after this PR is merged in:
# https://github.com/pytorch/pytorch/pull/22245
class Flatten(nn.Module):
def forward(self, input):
return input.view(input.size(0), -1)
if __name__ == '__main__':
main()
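The vmap call in `train` above batches exactly the per-task adaptation that `test` spells out as a Python loop. As a minimal sketch of that per-task computation in plain PyTorch (hypothetical toy shapes, a bare linear map instead of the CNN, same inner-loop learning rate of 1e-1), the differentiable inner loop looks like this:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
w = torch.randn(5, 10, requires_grad=True)                  # toy linear "model": 10 features -> 5 classes
x_spt, y_spt = torch.randn(8, 10), torch.randint(5, (8,))   # support set (hypothetical sizes)
x_qry, y_qry = torch.randn(8, 10), torch.randint(5, (8,))   # query set

new_w = w
for _ in range(5):                                           # inner loop, lr = 1e-1 as above
    spt_loss = F.cross_entropy(x_spt @ new_w.t(), y_spt)
    g, = torch.autograd.grad(spt_loss, new_w, create_graph=True)
    new_w = new_w - 1e-1 * g                                 # differentiable update

qry_loss = F.cross_entropy(x_qry @ new_w.t(), y_qry)
qry_loss.backward()                                          # meta-gradient reaches the original w
print(w.grad.shape)

Because the inner updates keep their graph (create_graph=True), the final backward reaches the original parameters, which is what meta_opt.step() relies on in the training loop above.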

View File

@ -0,0 +1,303 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# These Omniglot loaders are from Jackie Loong's PyTorch MAML implementation:
# https://github.com/dragen1860/MAML-Pytorch
# https://github.com/dragen1860/MAML-Pytorch/blob/master/omniglot.py
# https://github.com/dragen1860/MAML-Pytorch/blob/master/omniglotNShot.py
import torchvision.transforms as transforms
from PIL import Image
import numpy as np
import torch
import torch.utils.data as data
import os
import os.path
import errno
class Omniglot(data.Dataset):
urls = [
'https://github.com/brendenlake/omniglot/raw/master/python/images_background.zip',
'https://github.com/brendenlake/omniglot/raw/master/python/images_evaluation.zip'
]
raw_folder = 'raw'
processed_folder = 'processed'
training_file = 'training.pt'
test_file = 'test.pt'
'''
The items are (filename, category). The index of each category can be found in self.idx_classes.
Args:
- root: the directory where the dataset will be stored
- transform: how to transform the input
- target_transform: how to transform the target
- download: whether to download the dataset if it is not found locally
'''
def __init__(self, root, transform=None, target_transform=None,
download=False):
self.root = root
self.transform = transform
self.target_transform = target_transform
if not self._check_exists():
if download:
self.download()
else:
raise RuntimeError('Dataset not found.' + ' You can use download=True to download it')
self.all_items = find_classes(os.path.join(self.root, self.processed_folder))
self.idx_classes = index_classes(self.all_items)
def __getitem__(self, index):
filename = self.all_items[index][0]
img = str.join('/', [self.all_items[index][2], filename])
target = self.idx_classes[self.all_items[index][1]]
if self.transform is not None:
img = self.transform(img)
if self.target_transform is not None:
target = self.target_transform(target)
return img, target
def __len__(self):
return len(self.all_items)
def _check_exists(self):
return os.path.exists(os.path.join(self.root, self.processed_folder, "images_evaluation")) and \
os.path.exists(os.path.join(self.root, self.processed_folder, "images_background"))
def download(self):
from six.moves import urllib
import zipfile
if self._check_exists():
return
# download files
try:
os.makedirs(os.path.join(self.root, self.raw_folder))
os.makedirs(os.path.join(self.root, self.processed_folder))
except OSError as e:
if e.errno == errno.EEXIST:
pass
else:
raise
for url in self.urls:
print('== Downloading ' + url)
data = urllib.request.urlopen(url)
filename = url.rpartition('/')[2]
file_path = os.path.join(self.root, self.raw_folder, filename)
with open(file_path, 'wb') as f:
f.write(data.read())
file_processed = os.path.join(self.root, self.processed_folder)
print("== Unzip from " + file_path + " to " + file_processed)
zip_ref = zipfile.ZipFile(file_path, 'r')
zip_ref.extractall(file_processed)
zip_ref.close()
print("Download finished.")
def find_classes(root_dir):
retour = []
for (root, dirs, files) in os.walk(root_dir):
for f in files:
if (f.endswith("png")):
r = root.split('/')
lr = len(r)
retour.append((f, r[lr - 2] + "/" + r[lr - 1], root))
print("== Found %d items " % len(retour))
return retour
def index_classes(items):
idx = {}
for i in items:
if i[1] not in idx:
idx[i[1]] = len(idx)
print("== Found %d classes" % len(idx))
return idx
class OmniglotNShot:
def __init__(self, root, batchsz, n_way, k_shot, k_query, imgsz, device=None):
"""
Builds N-way, k-shot episodes from Omniglot (unlike the plain dataset above, this caches pre-sampled tasks).
:param root: directory where omniglot.npy is stored
:param batchsz: number of tasks per meta-batch
:param n_way: number of classes per task
:param k_shot: support examples per class
:param k_query: query examples per class
:param imgsz: image size after resizing
:param device: device the cached tensors are placed on
"""
self.resize = imgsz
self.device = device
if not os.path.isfile(os.path.join(root, 'omniglot.npy')):
# if root/omniglot.npy does not exist, download and preprocess the raw data
self.x = Omniglot(
root, download=True,
transform=transforms.Compose(
[lambda x: Image.open(x).convert('L'),
lambda x: x.resize((imgsz, imgsz)),
lambda x: np.reshape(x, (imgsz, imgsz, 1)),
lambda x: np.transpose(x, [2, 0, 1]),
lambda x: x/255.]),
)
temp = dict() # {label: [img1, img2, ... 20 imgs], label2: [...], ...}; 1623 labels in total
for (img, label) in self.x:
if label in temp.keys():
temp[label].append(img)
else:
temp[label] = [img]
self.x = []
for label, imgs in temp.items(): # label info is discarded; each label contains 20 imgs
self.x.append(np.array(imgs))
# as different classes may have different numbers of imgs
self.x = np.array(self.x).astype(np.float64) # [[20 imgs], ..., 1623 classes in total]
# each character contains 20 imgs
print('data shape:', self.x.shape) # [1623, 20, 1, imgsz, imgsz]
temp = [] # Free memory
# save all dataset into npy file.
np.save(os.path.join(root, 'omniglot.npy'), self.x)
print('write into omniglot.npy.')
else:
# if omniglot.npy exists, just load it.
self.x = np.load(os.path.join(root, 'omniglot.npy'))
print('load from omniglot.npy.')
# [1623, 20, 1, imgsz, imgsz]
# TODO: can not shuffle here, we must keep training and test set distinct!
self.x_train, self.x_test = self.x[:1200], self.x[1200:]
# self.normalization()
self.batchsz = batchsz
self.n_cls = self.x.shape[0] # 1623
self.n_way = n_way # n way
self.k_shot = k_shot # k shot
self.k_query = k_query # k query
assert (k_shot + k_query) <=20
# save pointer of current read batch in total cache
self.indexes = {"train": 0, "test": 0}
self.datasets = {"train": self.x_train, "test": self.x_test} # original data cached
print("DB: train", self.x_train.shape, "test", self.x_test.shape)
self.datasets_cache = {"train": self.load_data_cache(self.datasets["train"]), # current epoch data cached
"test": self.load_data_cache(self.datasets["test"])}
def normalization(self):
"""
Normalizes our data to have a mean of 0 and std of 1
"""
self.mean = np.mean(self.x_train)
self.std = np.std(self.x_train)
self.max = np.max(self.x_train)
self.min = np.min(self.x_train)
# print("before norm:", "mean", self.mean, "max", self.max, "min", self.min, "std", self.std)
self.x_train = (self.x_train - self.mean) / self.std
self.x_test = (self.x_test - self.mean) / self.std
self.mean = np.mean(self.x_train)
self.std = np.std(self.x_train)
self.max = np.max(self.x_train)
self.min = np.min(self.x_train)
# print("after norm:", "mean", self.mean, "max", self.max, "min", self.min, "std", self.std)
def load_data_cache(self, data_pack):
"""
Collects several batches of data for N-shot learning
:param data_pack: [cls_num, 20, 84, 84, 1]
:return: A list with [support_set_x, support_set_y, target_x, target_y] ready to be fed to our networks
"""
# take 5 way 1 shot as example: 5 * 1
setsz = self.k_shot * self.n_way
querysz = self.k_query * self.n_way
data_cache = []
# print('preload next 50 caches of batchsz of batch.')
for sample in range(10): # num of episodes
x_spts, y_spts, x_qrys, y_qrys = [], [], [], []
for i in range(self.batchsz): # one batch means one set
x_spt, y_spt, x_qry, y_qry = [], [], [], []
selected_cls = np.random.choice(data_pack.shape[0], self.n_way, False)
for j, cur_class in enumerate(selected_cls):
selected_img = np.random.choice(20, self.k_shot + self.k_query, False)
# meta-training and meta-test
x_spt.append(data_pack[cur_class][selected_img[:self.k_shot]])
x_qry.append(data_pack[cur_class][selected_img[self.k_shot:]])
y_spt.append([j for _ in range(self.k_shot)])
y_qry.append([j for _ in range(self.k_query)])
# shuffle inside a batch
perm = np.random.permutation(self.n_way * self.k_shot)
x_spt = np.array(x_spt).reshape(self.n_way * self.k_shot, 1, self.resize, self.resize)[perm]
y_spt = np.array(y_spt).reshape(self.n_way * self.k_shot)[perm]
perm = np.random.permutation(self.n_way * self.k_query)
x_qry = np.array(x_qry).reshape(self.n_way * self.k_query, 1, self.resize, self.resize)[perm]
y_qry = np.array(y_qry).reshape(self.n_way * self.k_query)[perm]
# append [sptsz, 1, 84, 84] => [b, setsz, 1, 84, 84]
x_spts.append(x_spt)
y_spts.append(y_spt)
x_qrys.append(x_qry)
y_qrys.append(y_qry)
# [b, setsz, 1, 84, 84]
x_spts = np.array(x_spts).astype(np.float32).reshape(self.batchsz, setsz, 1, self.resize, self.resize)
y_spts = np.array(y_spts).astype(np.int64).reshape(self.batchsz, setsz)
# [b, qrysz, 1, 84, 84]
x_qrys = np.array(x_qrys).astype(np.float32).reshape(self.batchsz, querysz, 1, self.resize, self.resize)
y_qrys = np.array(y_qrys).astype(np.int64).reshape(self.batchsz, querysz)
x_spts, y_spts, x_qrys, y_qrys = [
torch.from_numpy(z).to(self.device) for z in
[x_spts, y_spts, x_qrys, y_qrys]
]
data_cache.append([x_spts, y_spts, x_qrys, y_qrys])
return data_cache
def next(self, mode='train'):
"""
Gets next batch from the dataset with name.
:param mode: The name of the split (one of "train", "test")
:return:
"""
# refill the cache once the pointer runs past the number of cached batches
if self.indexes[mode] >= len(self.datasets_cache[mode]):
self.indexes[mode] = 0
self.datasets_cache[mode] = self.load_data_cache(self.datasets[mode])
next_batch = self.datasets_cache[mode][self.indexes[mode]]
self.indexes[mode] += 1
return next_batch
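A hedged usage sketch of the loaders above (the root path and episode sizes are illustrative; the class downloads Omniglot on first use):

# Illustrative usage of OmniglotNShot as defined above; the path is hypothetical.
# next('train') returns a meta-batch of episodes with shapes
# x_spt: [batchsz, n_way * k_shot, 1, imgsz, imgsz], y_spt: [batchsz, n_way * k_shot],
# x_qry: [batchsz, n_way * k_query, 1, imgsz, imgsz], y_qry: [batchsz, n_way * k_query].
import torch

db = OmniglotNShot('/tmp/omniglot-nshot', batchsz=32, n_way=5, k_shot=1,
                   k_query=15, imgsz=28, device=torch.device('cpu'))
x_spt, y_spt, x_qry, y_qry = db.next('train')
print(x_spt.shape, y_spt.shape, x_qry.shape, y_qry.shape)
# expected: [32, 5, 1, 28, 28] [32, 5] [32, 75, 1, 28, 28] [32, 75]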

View File

@ -0,0 +1,113 @@
import math
import random
import torch
import numpy as np
from torch import nn
from torch.nn import functional as F
import matplotlib as mpl
mpl.use('Agg')
import matplotlib.pyplot as plt
def net(x, params):
x = F.linear(x, params[0], params[1])
x = F.relu(x)
x = F.linear(x, params[2], params[3])
x = F.relu(x)
x = F.linear(x, params[4], params[5])
return x
params = [
torch.Tensor(40, 1).uniform_(-1., 1.).requires_grad_(),
torch.Tensor(40).zero_().requires_grad_(),
torch.Tensor(40, 40).uniform_(-1./math.sqrt(40), 1./math.sqrt(40)).requires_grad_(),
torch.Tensor(40).zero_().requires_grad_(),
torch.Tensor(1, 40).uniform_(-1./math.sqrt(40), 1./math.sqrt(40)).requires_grad_(),
torch.Tensor(1).zero_().requires_grad_(),
]
opt = torch.optim.Adam(params, lr=1e-3)
alpha = 0.1
K = 20
losses = []
num_tasks = 4
def sample_tasks(outer_batch_size, inner_batch_size):
# Select amplitude and phase for the task
As = []
phases = []
for _ in range(outer_batch_size):
As.append(np.random.uniform(low=0.1, high=.5))
phases.append(np.random.uniform(low=0., high=np.pi))
def get_batch():
xs, ys = [], []
for A, phase in zip(As, phases):
x = np.random.uniform(low=-5., high=5., size=(inner_batch_size, 1))
y = A * np.sin(x + phase)
xs.append(x)
ys.append(y)
return torch.tensor(xs, dtype=torch.float), torch.tensor(ys, dtype=torch.float)
x1, y1 = get_batch()
x2, y2 = get_batch()
return x1, y1, x2, y2
for it in range(20000):
loss2 = 0.0
opt.zero_grad()
def get_loss_for_task(x1, y1, x2, y2):
f = net(x1, params)
loss = F.mse_loss(f, y1)
# create_graph=True because computing grads here is part of the forward pass.
# We want to differentiate through the SGD update steps and get higher order
# derivatives in the backward pass.
grads = torch.autograd.grad(loss, params, create_graph=True)
new_params = [(params[i] - alpha*grads[i]) for i in range(len(params))]
v_f = net(x2, new_params)
return F.mse_loss(v_f, y2)
task = sample_tasks(num_tasks, K)
inner_losses = [get_loss_for_task(task[0][i], task[1][i], task[2][i], task[3][i]) for i in range(num_tasks)]
loss2 = sum(inner_losses)/len(inner_losses)
loss2.backward()
opt.step()
if it % 100 == 0:
print('Iteration %d -- Outer Loss: %.4f' % (it, loss2))
losses.append(loss2.item())  # store a plain float; np.convolve below cannot handle tensors that require grad
t_A = torch.tensor(0.0).uniform_(0.1, 0.5)
t_b = torch.tensor(0.0).uniform_(0.0, math.pi)
t_x = torch.empty(4, 1).uniform_(-5, 5)
t_y = t_A*torch.sin(t_x + t_b)
opt.zero_grad()
t_params = params
for k in range(5):
t_f = net(t_x, t_params)
t_loss = F.l1_loss(t_f, t_y)
grads = torch.autograd.grad(t_loss, t_params, create_graph=True)
t_params = [(t_params[i] - alpha*grads[i]) for i in range(len(params))]
test_x = torch.arange(-2*math.pi, 2*math.pi, step=0.01).unsqueeze(1)
test_y = t_A*torch.sin(test_x + t_b)
test_f = net(test_x, t_params)
plt.plot(test_x.data.numpy(), test_y.data.numpy(), label='sin(x)')
plt.plot(test_x.data.numpy(), test_f.data.numpy(), label='net(x)')
plt.plot(t_x.data.numpy(), t_y.data.numpy(), 'o', label='Examples')
plt.legend()
plt.savefig('maml-sine.png')
plt.figure()
plt.plot(np.convolve(losses, [.05]*20))
plt.savefig('losses.png')
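The create_graph=True comment above is the crux of the method: the inner SGD step must itself stay differentiable. A tiny self-contained sketch of differentiating through one gradient step, independent of the sine-fitting setup (values chosen only for illustration):

import torch

w = torch.tensor(2.0, requires_grad=True)
x = torch.tensor(3.0)
alpha = 0.1

inner_loss = (w * x) ** 2
g, = torch.autograd.grad(inner_loss, w, create_graph=True)   # g = 2 * x**2 * w, still on the graph
w_adapted = w - alpha * g                                    # one differentiable SGD step
outer_loss = (w_adapted * x) ** 2
outer_loss.backward()                                        # backprops through the update step
print(w.grad)   # analytically 2 * x**2 * w * (1 - 2 * alpha * x**2) ** 2 = 23.04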

View File

@ -0,0 +1,118 @@
import math
import random
import torch
import numpy as np
from torch import nn
from torch.nn import functional as F
import matplotlib as mpl
mpl.use('Agg')
import matplotlib.pyplot as plt
from functorch import grad, vmap
def net(params, x):
x = F.linear(x, params[0], params[1])
x = F.relu(x)
x = F.linear(x, params[2], params[3])
x = F.relu(x)
x = F.linear(x, params[4], params[5])
return x
params = [
torch.Tensor(40, 1).uniform_(-1., 1.).requires_grad_(),
torch.Tensor(40).zero_().requires_grad_(),
torch.Tensor(40, 40).uniform_(-1./math.sqrt(40), 1./math.sqrt(40)).requires_grad_(),
torch.Tensor(40).zero_().requires_grad_(),
torch.Tensor(1, 40).uniform_(-1./math.sqrt(40), 1./math.sqrt(40)).requires_grad_(),
torch.Tensor(1).zero_().requires_grad_(),
]
# The prototype doesn't like F.mse_loss.
def mse_loss(x, y):
return torch.mean((x - y) ** 2)
opt = torch.optim.Adam(params, lr=1e-3)
alpha = 0.1
K = 20
losses = []
num_tasks = 4
def sample_tasks(outer_batch_size, inner_batch_size):
# Select amplitude and phase for the task
As = []
phases = []
for _ in range(outer_batch_size):
As.append(np.random.uniform(low=0.1, high=.5))
phases.append(np.random.uniform(low=0., high=np.pi))
def get_batch():
xs, ys = [], []
for A, phase in zip(As, phases):
x = np.random.uniform(low=-5., high=5., size=(inner_batch_size, 1))
y = A * np.sin(x + phase)
xs.append(x)
ys.append(y)
return torch.tensor(xs, dtype=torch.float), torch.tensor(ys, dtype=torch.float)
x1, y1 = get_batch()
x2, y2 = get_batch()
return x1, y1, x2, y2
for it in range(20000):
loss2 = 0.0
opt.zero_grad()
def get_loss_for_task(x1, y1, x2, y2):
def inner_loss(params, x1, y1):
f = net(params, x1)
loss = mse_loss(f, y1)
return loss
grads = grad(inner_loss)(tuple(params), x1, y1)
new_params = [(params[i] - alpha*grads[i]) for i in range(len(params))]
v_f = net(new_params, x2)
return mse_loss(v_f, y2)
task = sample_tasks(num_tasks, K)
inner_losses = vmap(get_loss_for_task)(task[0], task[1], task[2], task[3])
loss2 = sum(inner_losses)/len(inner_losses)
loss2.backward()
opt.step()
if it % 100 == 0:
print('Iteration %d -- Outer Loss: %.4f' % (it, loss2))
losses.append(loss2.item())  # store a plain float; np.convolve below cannot handle tensors that require grad
t_A = torch.tensor(0.0).uniform_(0.1, 0.5)
t_b = torch.tensor(0.0).uniform_(0.0, math.pi)
t_x = torch.empty(4, 1).uniform_(-5, 5)
t_y = t_A*torch.sin(t_x + t_b)
opt.zero_grad()
t_params = params
for k in range(5):
t_f = net(t_x, t_params)
t_loss = F.l1_loss(t_f, t_y)
grads = torch.autograd.grad(t_loss, t_params, create_graph=True)
t_params = [(t_params[i] - alpha*grads[i]) for i in range(len(params))]
test_x = torch.arange(-2*math.pi, 2*math.pi, step=0.01).unsqueeze(1)
test_y = t_A*torch.sin(test_x + t_b)
test_f = net(test_x, t_params)
plt.plot(test_x.data.numpy(), test_y.data.numpy(), label='sin(x)')
plt.plot(test_x.data.numpy(), test_f.data.numpy(), label='net(x)')
plt.plot(t_x.data.numpy(), t_y.data.numpy(), 'o', label='Examples')
plt.legend()
plt.savefig('maml-sine.png')
plt.figure()
plt.plot(np.convolve(losses, [.05]*20))
plt.savefig('losses.png')

View File

@ -0,0 +1,115 @@
import math
import random
import torch
import numpy as np
from torch import nn
from torch.nn import functional as F
import matplotlib as mpl
mpl.use('Agg')
import matplotlib.pyplot as plt
from functorch import grad, vmap, make_functional
class ThreeLayerNet(nn.Module):
def __init__(self):
super(ThreeLayerNet, self).__init__()
self.fc1 = nn.Linear(1, 40)
self.relu1 = nn.ReLU()
self.fc2 = nn.Linear(40, 40)
self.relu2 = nn.ReLU()
self.fc3 = nn.Linear(40, 1)
def forward(self, x):
x = self.fc1(x)
x = self.relu1(x)
x = self.fc2(x)
x = self.relu2(x)
x = self.fc3(x)
return x
# The prototype doesn't like F.mse_loss.
def mse_loss(x, y):
return torch.mean((x - y) ** 2)
params, net, _ = make_functional(ThreeLayerNet())
opt = torch.optim.Adam(params, lr=1e-3)
alpha = 0.1
K = 20
losses = []
num_tasks = 4
def sample_tasks(outer_batch_size, inner_batch_size):
# Select amplitude and phase for the task
As = []
phases = []
for _ in range(outer_batch_size):
As.append(np.random.uniform(low=0.1, high=.5))
phases.append(np.random.uniform(low=0., high=np.pi))
def get_batch():
xs, ys = [], []
for A, phase in zip(As, phases):
x = np.random.uniform(low=-5., high=5., size=(inner_batch_size, 1))
y = A * np.sin(x + phase)
xs.append(x)
ys.append(y)
return torch.tensor(xs, dtype=torch.float), torch.tensor(ys, dtype=torch.float)
x1, y1 = get_batch()
x2, y2 = get_batch()
return x1, y1, x2, y2
for it in range(20000):
loss2 = 0.0
opt.zero_grad()
def get_loss_for_task(x1, y1, x2, y2):
def inner_loss(params, x1, y1):
f = net(params, (x1,))
loss = mse_loss(f, y1)
return loss
grads = grad(inner_loss)(params, x1, y1)
new_params = [(params[i] - alpha*grads[i]) for i in range(len(params))]
v_f = net(new_params, (x2,))
return mse_loss(v_f, y2)
task = sample_tasks(num_tasks, K)
inner_losses = vmap(get_loss_for_task)(task[0], task[1], task[2], task[3])
loss2 = sum(inner_losses)/len(inner_losses)
loss2.backward()
opt.step()
if it % 100 == 0:
print('Iteration %d -- Outer Loss: %.4f' % (it, loss2))
losses.append(loss2.item())  # store a plain float; np.convolve below cannot handle tensors that require grad
t_A = torch.tensor(0.0).uniform_(0.1, 0.5)
t_b = torch.tensor(0.0).uniform_(0.0, math.pi)
t_x = torch.empty(4, 1).uniform_(-5, 5)
t_y = t_A*torch.sin(t_x + t_b)
opt.zero_grad()
t_params = params
for k in range(5):
t_f = net(t_params, (t_x,))
t_loss = F.l1_loss(t_f, t_y)
grads = torch.autograd.grad(t_loss, t_params, create_graph=True)
t_params = [(t_params[i] - alpha*grads[i]) for i in range(len(params))]
test_x = torch.arange(-2*math.pi, 2*math.pi, step=0.01).unsqueeze(1)
test_y = t_A*torch.sin(test_x + t_b)
test_f = net(t_params, (test_x,))
plt.plot(test_x.data.numpy(), test_y.data.numpy(), label='sin(x)')
plt.plot(test_x.data.numpy(), test_f.data.numpy(), label='net(x)')
plt.plot(t_x.data.numpy(), t_y.data.numpy(), 'o', label='Examples')
plt.legend()
plt.savefig('maml-sine.png')
plt.figure()
plt.plot(np.convolve(losses, [.05]*20))
plt.savefig('losses.png')

View File

@ -0,0 +1,6 @@
import torch
from . import _C
from ._src.vmap import vmap
from ._src.eager_transforms import grad, grad_with_value, vjp, jacrev
from ._src.make_functional import make_functional, make_functional_with_buffers
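For reference, a hedged sketch of the transforms re-exported here, following the Python definitions in _src/eager_transforms.py and _src/vmap.py from this commit (prototype semantics; whether these run as-is depends on the compiled _C extension):

import torch
from functorch import grad, vmap, jacrev

# grad(f) returns the gradient of a scalar-valued f with respect to its first argument.
f = lambda x: (x ** 2).sum()
x = torch.randn(3)
print(grad(f)(x))                 # should equal 2 * x

# vmap(g) maps g over dimension 0 of each input (mirroring the batched dot example in _src/vmap.py).
xs, ys = torch.randn(4, 3), torch.randn(4, 3)
print(vmap(torch.dot)(xs, ys))    # shape [4]

# jacrev(f) builds the Jacobian from vmapped vector-Jacobian products.
print(jacrev(torch.sin)(x))       # a 3x3 matrix with cos(x) on the diagonal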

View File

View File

@ -0,0 +1,185 @@
import torch
from functools import partial
import collections
import torch.nn as nn
import torch.nn.functional as F
from .make_functional import make_functional, make_functional_with_buffers
import gc
from .vmap import vmap
from functorch._C import (
_wrap_for_grad,
_unwrap_for_grad,
_grad_increment_nesting,
_grad_decrement_nesting,
)
# x = torch.ones(2, 3)
# y = torch.ones(2, 3)
# # result = vmap(torch.add)(x, y)
# result = vmap(vmap(torch.add))(x, y)
# assert torch.allclose(result, x + y)
# TODO: replace all of these with pytrees
def _create_differentiable(tensor_or_tuple_of_tensors, level=None):
if isinstance(tensor_or_tuple_of_tensors, torch.Tensor):
tensor = tensor_or_tuple_of_tensors
aliased = tensor
return aliased.requires_grad_()
if isinstance(tensor_or_tuple_of_tensors, tuple):
return tuple(map(partial(_create_differentiable, level=level), tensor_or_tuple_of_tensors))
if isinstance(tensor_or_tuple_of_tensors, list):
return tuple(map(partial(_create_differentiable, level=level), tensor_or_tuple_of_tensors))
raise ValueError(f'Thing passed to transform API must be Tensor, List or Tuple, '
f'got {type(tensor_or_tuple_of_tensors)}')
def _undo_create_differentiable(tensor_or_tuple_of_tensors, level=None):
if isinstance(tensor_or_tuple_of_tensors, torch.Tensor):
tensor = tensor_or_tuple_of_tensors
return _unwrap_for_grad(tensor, level)
if isinstance(tensor_or_tuple_of_tensors, tuple):
return tuple(map(partial(_undo_create_differentiable, level=level), tensor_or_tuple_of_tensors))
if isinstance(tensor_or_tuple_of_tensors, list):
return tuple(map(partial(_undo_create_differentiable, level=level), tensor_or_tuple_of_tensors))
assert False
def _any_differentiable(tensor_or_tuple_of_tensors):
if isinstance(tensor_or_tuple_of_tensors, torch.Tensor):
tensor = tensor_or_tuple_of_tensors
return tensor.requires_grad
if isinstance(tensor_or_tuple_of_tensors, tuple):
return any(tuple(map(_any_differentiable, tensor_or_tuple_of_tensors)))
if isinstance(tensor_or_tuple_of_tensors, list):
return any(tuple(map(_any_differentiable, tensor_or_tuple_of_tensors)))
return False
def _wrap_all_tensors(tensor_or_tuple_of_tensors, level):
if isinstance(tensor_or_tuple_of_tensors, torch.Tensor):
tensor = tensor_or_tuple_of_tensors
return _wrap_for_grad(tensor, level)
if isinstance(tensor_or_tuple_of_tensors, tuple):
return tuple(map(partial(_wrap_all_tensors, level=level), tensor_or_tuple_of_tensors))
if isinstance(tensor_or_tuple_of_tensors, list):
return tuple(map(partial(_wrap_all_tensors, level=level), tensor_or_tuple_of_tensors))
return tensor_or_tuple_of_tensors
# How do we increment and decrement the nesting? I don't think we can.
def vjp(f, *primals):
level = _grad_increment_nesting()
try:
primals = _wrap_all_tensors(primals, level)
diff_primals = _create_differentiable(primals, level)
primals_out = f(*diff_primals)
results = _undo_create_differentiable(primals_out, level)
def wrapper(*cotangents, retain_graph=True, create_graph=True):
result = torch.autograd.grad(primals_out, diff_primals, cotangents,
retain_graph=retain_graph, create_graph=create_graph)
return result
finally:
_grad_decrement_nesting()
return results, wrapper
def jacrev(f):
def wrapper_fn(primal):
output, vjp_fn = vjp(f, primal)
assert isinstance(output, torch.Tensor)
# TODO: does jacrev compose with vmap...? the eye call should make it so that it doesn't
basis = torch.eye(output.numel(), dtype=output.dtype, device=output.device) \
.view(output.numel(), *output.shape)
result, = vmap(vjp_fn)(basis)
result = result.view(*output.shape, *primal.shape)
return result
return wrapper_fn
#
#
# def jacrev(f, diff_argnums=(0,)):
# def wrapper(*args):
# torch._C._grad_increment_nesting()
# output = None
# grad_outputs = None
# try:
# args = [_create_differentiable(arg) if i in diff_argnums else arg
# for i, arg in enumerate(args)]
# output = f(*args)
# # Only support single tensor output for now
# assert isinstance(output, torch.Tensor)
# output_numel = output.numel()
# if output_numel != 0:
# grad_output = torch.eye(output_numel).view(output_numel, *output.shape)
#
# diff_args = [args[i] for i in diff_argnums]
# single_diff_arg = isinstance(diff_args[0], torch.Tensor) and len(diff_args) == 1
# # TODO: quick hack...
# if len(diff_args) == 1 and isinstance(diff_args[0], tuple):
# diff_args = diff_args[0]
# # NB: need create_graph so that backward pass isn't run in no_grad mode
#
# def compute_vjp(v):
# return torch.autograd.grad(output, diff_args, v, create_graph=True)
#
# if output_numel == 0:
# grad_input = compute_vjp(grad_output)
# else:
# grad_input = vmap(compute_vjp)(grad_output)
#
# if single_diff_arg:
# grad_input = grad_input[0]
# finally:
# _undo_create_differentiable(args)
# torch._C._grad_decrement_nesting()
# return grad_input, output
# return wrapper
def grad_with_value(f, diff_argnums=(0,), has_aux=False):
def wrapper(*args):
level = _grad_increment_nesting()
output, aux, grad_input = None, None, None
try:
args = _wrap_all_tensors(args, level)
args = [_create_differentiable(arg, level) if i in diff_argnums else arg
for i, arg in enumerate(args)]
# print("calling f(*args)")
output = f(*args)
# print("done with f(*args)")
if has_aux:
output, aux = output
# print("calling output.dim()")
assert output.dim() == 0
diff_args = [args[i] for i in diff_argnums]
single_diff_arg = isinstance(diff_args[0], torch.Tensor) and len(diff_args) == 1
# TODO: quick hack...
if len(diff_args) == 1 and isinstance(diff_args[0], tuple):
diff_args = diff_args[0]
# NB: need create_graph so that backward pass isn't run in no_grad mode
# import torchviz; import graphviz
# graph = torchviz.make_dot(output)
# graph.save("inner.dot")
# print("calling autograd.grad")
grad_input = torch.autograd.grad(
output, diff_args, create_graph=True)
# print("done-ish!")
if single_diff_arg:
grad_input = grad_input[0]
finally:
if grad_input is not None:
grad_input = _undo_create_differentiable(grad_input, level)
_grad_decrement_nesting()
if has_aux:
return grad_input, output, aux
return grad_input, output
return wrapper
def grad(f, diff_argnums=(0,), has_aux=False):
def wrapper(*args):
results = grad_with_value(f, diff_argnums, has_aux=has_aux)(*args)
if has_aux:
return results[0], results[2]
return results[0]
return wrapper
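A hedged usage sketch of the `diff_argnums` and `has_aux` knobs, following the `grad` / `grad_with_value` definitions above (the toy loss function is made up for illustration):

import torch
from functorch import grad, grad_with_value

def loss_fn(w, x):
    pred = x @ w                        # hypothetical toy model: a single linear map
    return (pred ** 2).mean(), pred     # (scalar loss, auxiliary output)

w, x = torch.randn(3), torch.randn(5, 3)

# diff_argnums picks the positional argument to differentiate (default: the first);
# has_aux=True lets loss_fn return (loss, aux) and passes aux through untouched.
g, aux = grad(loss_fn, diff_argnums=(0,), has_aux=True)(w, x)
print(g.shape, aux.shape)               # [3] and [5]

# grad_with_value also returns the scalar loss itself.
g2, value = grad_with_value(lambda w, x: ((x @ w) ** 2).mean())(w, x)
print(value)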

View File

@ -0,0 +1,99 @@
import torch
import torch.nn as nn
from torch import Tensor
from typing import List, Tuple
import copy
# Utilities to make nn.Module "functional"
# In particular the goal is to be able to provide a function that takes as input
# the parameters and evaluate the nn.Module using fixed inputs.
def _del_nested_attr(obj: nn.Module, names: List[str]) -> None:
"""
Deletes the attribute specified by the given list of names.
For example, to delete the attribute obj.conv.weight,
use _del_nested_attr(obj, ['conv', 'weight'])
"""
if len(names) == 1:
delattr(obj, names[0])
else:
_del_nested_attr(getattr(obj, names[0]), names[1:])
def _set_nested_attr(obj: nn.Module, names: List[str], value: Tensor) -> None:
"""
Set the attribute specified by the given list of names to value.
For example, to set the attribute obj.conv.weight,
use _set_nested_attr(obj, ['conv', 'weight'], value)
"""
if len(names) == 1:
setattr(obj, names[0], value)
else:
_set_nested_attr(getattr(obj, names[0]), names[1:], value)
def extract_weights(mod: nn.Module) -> Tuple[Tuple[Tensor, ...], List[str]]:
"""
This function removes all the Parameters from the model and
returns them as a tuple, along with their original attribute names.
The weights must be re-loaded with `load_weights` before the model
can be used again.
Note that this function modifies the model in place and after this
call, mod.parameters() will be empty.
"""
orig_params = tuple(mod.parameters())
# Remove all the parameters in the model
names = []
for name, p in list(mod.named_parameters()):
_del_nested_attr(mod, name.split("."))
names.append(name)
# Gather the extracted parameters into a plain tuple (they are no longer attributes of the module)
params = tuple(p for p in orig_params)
return params, names
def load_weights(mod: nn.Module, names: List[str], params: Tuple[Tensor, ...], as_params=False) -> None:
"""
Reload a set of weights so that `mod` can be used again to perform a forward pass.
Note that the `params` are regular Tensors (that can have history) and so are left
as Tensors. This means that mod.parameters() will still be empty after this call.
"""
for name, p in zip(names, params):
if as_params:
p = nn.Parameter(p)
_set_nested_attr(mod, name.split("."), p)
def extract_buffers(mod: nn.Module) -> Tuple[Tuple[Tensor, ...], List[str]]:
orig_params = tuple(mod.buffers())
# Remove all the buffers in the model
names = []
for name, p in list(mod.named_buffers()):
_del_nested_attr(mod, name.split("."))
names.append(name)
# Gather the extracted buffers into a plain tuple
params = tuple(p for p in orig_params)
return params, names
def load_buffers(mod: nn.Module, names: List[str], params: Tuple[Tensor, ...], as_params=False) -> None:
for name, p in zip(names, params):
_set_nested_attr(mod, name.split("."), p)
def make_functional(model: nn.Module):
weights, descriptors = extract_weights(model)
def fun(weights, data):
mutable_model = copy.deepcopy(model)
load_weights(mutable_model, descriptors, weights)
return mutable_model(*data)
return weights, fun, descriptors
def make_functional_with_buffers(model: nn.Module):
weights, weight_descriptors = extract_weights(model)
buffers, buf_descriptors = extract_buffers(model)
def fun(weights, buffers, data):
mutable_model = copy.deepcopy(model)
load_weights(mutable_model, weight_descriptors, weights)
load_buffers(mutable_model, buf_descriptors, buffers)
return mutable_model(*data)
return weights, buffers, fun, weight_descriptors, buf_descriptors
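A hedged usage sketch of the helpers above (the models are chosen only for illustration). Note that the returned callable expects the forward inputs packed in a tuple and deep-copies the module on every call, which is simple but not cheap:

import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(3, 4), nn.ReLU(), nn.Linear(4, 1))
weights, fun, names = make_functional(model)   # defined above; also exported from functorch

x = torch.randn(2, 3)
out = fun(weights, (x,))          # equivalent to model(x), with parameters passed explicitly
print(out.shape, names)

# The buffer-aware variant threads buffers (e.g. BatchNorm running stats) through as well.
bn_model = nn.Sequential(nn.Linear(3, 4), nn.BatchNorm1d(4))
w, b, fun_b, w_names, b_names = make_functional_with_buffers(bn_model)
print(fun_b(w, b, (x,)).shape)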

View File

@ -0,0 +1,270 @@
import torch
import functools
from torch import Tensor
from typing import Any, Callable, Optional, Tuple, Union, List
from torch.utils._pytree import tree_flatten, tree_unflatten, _broadcast_to_and_flatten
import warnings
from functorch._C import (
_add_batch_dim,
_remove_batch_dim,
_vmapmode_decrement_nesting,
_vmapmode_increment_nesting,
)
in_dims_t = Union[int, Tuple]
out_dims_t = Union[int, Tuple[int, ...]]
# Checks that all args-to-be-batched have the same batch dim size
def _validate_and_get_batch_size(
flat_in_dims: List[Optional[int]],
flat_args: List) -> int:
batch_sizes = [arg.size(in_dim) for in_dim, arg in zip(flat_in_dims, flat_args)
if in_dim is not None]
if batch_sizes and any([size != batch_sizes[0] for size in batch_sizes]):
raise ValueError(
f'vmap: Expected all tensors to have the same size in the mapped '
f'dimension, got sizes {batch_sizes} for the mapped dimension')
return batch_sizes[0]
def _num_outputs(batched_outputs: Union[Tensor, Tuple[Tensor, ...]]) -> int:
if isinstance(batched_outputs, tuple):
return len(batched_outputs)
return 1
# If value is a tuple, check it has length `num_elements`.
# If value is not a tuple, make a tuple with `value` repeated `num_elements` times
def _as_tuple(value: Any, num_elements: int, error_message_lambda: Callable[[], str]) -> Tuple:
if not isinstance(value, tuple):
return (value,) * num_elements
if len(value) != num_elements:
raise ValueError(error_message_lambda())
return value
# Creates BatchedTensors for every Tensor in arg that should be batched.
# Returns the (potentially) batched arguments and the batch_size.
def _create_batched_inputs(
in_dims: in_dims_t, args: Tuple, vmap_level: int, func: Callable) -> Tuple[Tuple, int]:
if not isinstance(in_dims, int) and not isinstance(in_dims, tuple):
raise ValueError(
f'vmap({_get_name(func)}, in_dims={in_dims}, ...)(<inputs>): '
f'expected `in_dims` to be int or a (potentially nested) tuple '
f'matching the structure of inputs, got: {type(in_dims)}.')
if len(args) == 0:
raise ValueError(
f'vmap({_get_name(func)})(<inputs>): got no inputs. Maybe you forgot to add '
f'inputs, or you are trying to vmap over a function with no inputs. '
f'The latter is unsupported.')
flat_args, args_spec = tree_flatten(args)
flat_in_dims = _broadcast_to_and_flatten(in_dims, args_spec)
if flat_in_dims is None:
raise ValueError(
f'vmap({_get_name(func)}, in_dims={in_dims}, ...)(<inputs>): '
f'in_dims is not compatible with the structure of `inputs`. '
f'in_dims has structure {tree_flatten(in_dims)[1]} but inputs '
f'has structure {args_spec}.')
for arg, in_dim in zip(flat_args, flat_in_dims):
if not isinstance(in_dim, int) and in_dim is not None:
raise ValueError(
f'vmap({_get_name(func)}, in_dims={in_dims}, ...)(<inputs>): '
f'Got in_dim={in_dim} for an input but in_dim must be either '
f'an integer dimension or None.')
if isinstance(in_dim, int) and not isinstance(arg, Tensor):
raise ValueError(
f'vmap({_get_name(func)}, in_dims={in_dims}, ...)(<inputs>): '
f'Got in_dim={in_dim} for an input but the input is of type '
f'{type(arg)}. We cannot vmap over non-Tensor arguments, '
f'please use None as the respective in_dim')
if in_dim is not None and (in_dim < 0 or in_dim >= arg.dim()):
raise ValueError(
f'vmap({_get_name(func)}, in_dims={in_dims}, ...)(<inputs>): '
f'Got in_dim={in_dim} for some input, but that input is a Tensor '
f'of dimensionality {arg.dim()} so expected in_dim to satisfy '
f'0 <= in_dim < {arg.dim()}.')
batch_size = _validate_and_get_batch_size(flat_in_dims, flat_args)
# See NOTE [Ignored _remove_batch_dim, _add_batch_dim]
batched_inputs = [arg if in_dim is None else
_add_batch_dim(arg, in_dim, vmap_level) # type: ignore
for in_dim, arg in zip(flat_in_dims, flat_args)]
return tree_unflatten(batched_inputs, args_spec), batch_size
# Undoes the batching (and any batch dimensions) associated with the `vmap_level`.
def _unwrap_batched(
batched_outputs: Union[Tensor, Tuple[Tensor, ...]],
out_dims: out_dims_t,
vmap_level: int, batch_size: int, func: Callable) -> Tuple:
num_outputs = _num_outputs(batched_outputs)
out_dims_as_tuple = _as_tuple(
out_dims, num_outputs,
lambda: f'vmap({_get_name(func)}, ..., out_dims={out_dims}): `out_dims` must '
f'have one dim per output (got {num_outputs} outputs) of {_get_name(func)}.')
# NOTE [Ignored _remove_batch_dim, _add_batch_dim]
# There is something wrong with our type bindings for functions that begin
# with '_', see #40397.
if isinstance(batched_outputs, Tensor):
out_dim = out_dims_as_tuple[0]
return _remove_batch_dim(batched_outputs, vmap_level, batch_size, out_dim) # type: ignore
return tuple(_remove_batch_dim(out, vmap_level, batch_size, out_dim) # type: ignore
for out, out_dim in zip(batched_outputs, out_dims_as_tuple))
# Checks that `fn` returned one or more Tensors and nothing else.
# NB: A Python function that returns multiple values returns a single tuple,
# so we are effectively checking that `outputs` is a single Tensor or a tuple of
# Tensors.
def _validate_outputs(outputs: Any, func: Callable) -> None:
if isinstance(outputs, Tensor):
return
if not isinstance(outputs, tuple):
raise ValueError(f'vmap({_get_name(func)}, ...): `{_get_name(func)}` must only return '
f'Tensors, got type {type(outputs)} as the return.')
for idx, output in enumerate(outputs):
if isinstance(output, Tensor):
continue
raise ValueError(f'vmap({_get_name(func)}, ...): `{_get_name(func)}` must only return '
f'Tensors, got type {type(output)} for return {idx}.')
def _check_out_dims_is_int_or_int_tuple(out_dims: out_dims_t, func: Callable) -> None:
if isinstance(out_dims, int):
return
if not isinstance(out_dims, tuple) or \
not all([isinstance(out_dim, int) for out_dim in out_dims]):
raise ValueError(
f'vmap({_get_name(func)}, ..., out_dims={out_dims}): `out_dims` must be '
f'an int or a tuple of int representing where in the outputs the '
f'vmapped dimension should appear.')
def _get_name(func: Callable):
if hasattr(func, '__name__'):
return func.__name__
# Not all callables have a __name__; for example, a callable created via
# functools.partial or an nn.Module doesn't have a __name__.
return repr(func)
# vmap(func)(inputs) wraps all Tensor inputs to be batched in BatchedTensors,
# sends those into func, and then unwraps the output BatchedTensors. Operations
# on BatchedTensors perform the batched operations that the user is asking for.
def vmap(func: Callable, in_dims: in_dims_t = 0, out_dims: out_dims_t = 0) -> Callable:
"""
vmap is the vectorizing map. Returns a new function that maps `func` over some
dimension of the inputs. Semantically, vmap pushes the map into PyTorch
operations called by `func`, effectively vectorizing those operations.
vmap is useful for handling batch dimensions: one can write a function `func`
that runs on examples and then lift it to a function that can take batches of
examples with `vmap(func)`. vmap can also be used to compute batched
gradients when composed with autograd.
.. warning::
functorch.vmap is an experimental prototype that is subject to
change and/or deletion. Please use at your own risk.
.. note::
If you're interested in using vmap for your use case, please
`contact us! <https://github.com/pytorch/pytorch/issues/42368>`_
We're interested in gathering feedback from early adopters to inform
the design.
Args:
func (function): A Python function that takes one or more arguments.
Must return one or more Tensors.
in_dims (int or nested structure): Specifies which dimension of the
inputs should be mapped over. `in_dims` should have a structure
like the inputs. If the `in_dim` for a particular input is None,
then that indicates there is no map dimension. Default: 0.
out_dims (int or Tuple[int]): Specifies where the mapped dimension
should appear in the outputs. If `out_dims` is a Tuple, then it should
have one element per output. Default: 0.
Returns:
Returns a new "batched" function. It takes the same inputs as `func`,
except each input has an extra dimension at the index specified by `in_dims`.
It returns the same outputs as `func`, except each output has
an extra dimension at the index specified by `out_dims`.
.. warning::
vmap works best with functional-style code. Please do not perform any
side-effects in `func`, with the exception of in-place PyTorch operations.
Examples of side-effects include mutating Python data structures and
assigning values to variables not captured in `func`.
One example of using `vmap` is to compute batched dot products. PyTorch
doesn't provide a batched `torch.dot` API; instead of unsuccessfully
rummaging through docs, use `vmap` to construct a new function.
>>> torch.dot # [D], [D] -> []
>>> batched_dot = functorch.vmap(torch.dot) # [N, D], [N, D] -> [N]
>>> x, y = torch.randn(2, 5), torch.randn(2, 5)
>>> batched_dot(x, y)
`vmap` can be helpful in hiding batch dimensions, leading to a simpler
model authoring experience.
>>> batch_size, feature_size = 3, 5
>>> weights = torch.randn(feature_size, requires_grad=True)
>>>
>>> def model(feature_vec):
>>> # Very simple linear model with activation
>>> return feature_vec.dot(weights).relu()
>>>
>>> examples = torch.randn(batch_size, feature_size)
>>> result = functorch.vmap(model)(examples)
`vmap` can also help vectorize computations that were previously difficult
or impossible to batch. One example is higher-order gradient computation.
The PyTorch autograd engine computes vjps (vector-Jacobian products).
Computing a full Jacobian matrix for some function f: R^N -> R^N usually
requires N calls to `autograd.grad`, one per Jacobian row. Using `vmap`,
we can vectorize the whole computation, computing the Jacobian in a single
call to `autograd.grad`.
>>> # Setup
>>> N = 5
>>> f = lambda x: x ** 2
>>> x = torch.randn(N, requires_grad=True)
>>> y = f(x)
>>> I_N = torch.eye(N)
>>>
>>> # Sequential approach
>>> jacobian_rows = [torch.autograd.grad(y, x, v, retain_graph=True)[0]
>>> for v in I_N.unbind()]
>>> jacobian = torch.stack(jacobian_rows)
>>>
>>> # vectorized gradient computation
>>> def get_vjp(v):
>>> return torch.autograd.grad(y, x, v)
>>> jacobian = functorch.vmap(get_vjp)(I_N)
.. note::
vmap does not provide general autobatching or handle variable-length
sequences out of the box.
"""
warnings.warn(
'functorch.vmap is an experimental prototype that is subject to '
'change and/or deletion. Please use at your own risk. There may be '
'unexpected performance cliffs due to certain operators not being '
'implemented. To see detailed performance warnings please use '
'`torch._C._debug_only_display_vmap_fallback_warnings(True)` '
'before the call to `vmap`.',
stacklevel=2)
return _vmap(func, in_dims, out_dims)
# A version of vmap but without the initial "experimental prototype" warning
def _vmap(func: Callable, in_dims: in_dims_t = 0, out_dims: out_dims_t = 0) -> Callable:
@functools.wraps(func)
def wrapped(*args):
_check_out_dims_is_int_or_int_tuple(out_dims, func)
vmap_level = _vmapmode_increment_nesting()
try:
batched_inputs, batch_size = _create_batched_inputs(in_dims, args, vmap_level, func)
batched_outputs = func(*batched_inputs)
_validate_outputs(batched_outputs, func)
return _unwrap_batched(batched_outputs, out_dims, vmap_level, batch_size, func)
finally:
_vmapmode_decrement_nesting()
return wrapped
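The docstring above covers the default in_dims=0; a short hedged sketch of the non-default cases handled by _create_batched_inputs and _unwrap_batched (None marks an argument that should not be mapped, out_dims relocates the mapped dimension in the output):

import torch
from functorch import vmap

xs = torch.randn(8, 3)            # a batch of 8 vectors
w = torch.randn(3)                # a single shared vector

# Map over dim 0 of xs, but broadcast w (an in_dim of None means "do not map").
dots = vmap(torch.dot, in_dims=(0, None))(xs, w)
print(dots.shape)                 # [8]

# Place the mapped dimension at position 1 of the output instead of 0.
outer = vmap(torch.mul, in_dims=(0, None), out_dims=1)(xs, w)
print(outer.shape)                # [3, 8]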

View File

@ -0,0 +1,416 @@
#include <functorch/csrc/BatchedFallback.h>
#include <functorch/csrc/VmapTransforms.h>
#include <functorch/csrc/Constants.h>
#include <functorch/csrc/TensorWrapper.h>
#include <functorch/csrc/DynamicLayer.h>
#include <ATen/Context.h>
#include <ATen/MatrixRef.h>
#include <ATen/core/dispatch/Dispatcher.h>
#include <c10/util/accumulate.h>
#include <c10/util/llvmMathExtras.h>
namespace at {
namespace functorch {
// Given a linear index, return the actual index.
// Example: Given linear_idx = 3, sizes = [5, 2], we would return [1, 1]
static at::SmallVector<indexing::TensorIndex,kVmapStaticDimVecSize>
computeIndex(int64_t linear_idx, IntArrayRef sizes) {
at::SmallVector<indexing::TensorIndex,kVmapStaticDimVecSize> result;
result.reserve(sizes.size());
for (auto it = sizes.rbegin(); it != sizes.rend(); it++) {
auto remainder = linear_idx % *it;
result.push_back(remainder);
linear_idx -= remainder;
linear_idx /= *it;
}
std::reverse(std::begin(result), std::end(result));
return result;
}
static bool areAllReturnsTensors(const at::FunctionSchema& schema) {
return std::all_of(
schema.returns().begin(),
schema.returns().end(),
[] (const Argument& arg) { return arg.type() == TensorType::get(); });
}
static bool areAnyArgumentsTensorList(const at::FunctionSchema& schema) {
return std::any_of(
schema.arguments().begin(),
schema.arguments().end(),
[] (const Argument& arg) { return arg.type()->isSubtypeOf(ListType::ofTensors()); });
}
// Returns whether an operator is in-place. An operator is in-place if:
// 1. The first argument is a Tensor and it is being written to
// 2. The first argument is being returned
// 3. No other arguments are aliased
// Here is an example of an in-place operator:
// add_(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!)
static bool isInplaceOp(const FunctionSchema& schema) {
if (!schema.is_mutable() || schema.returns().size() != 1) {
return false;
}
// Check that the first argument is being written to
const auto& first_arg_alias_info = schema.arguments().begin()->alias_info();
if (!first_arg_alias_info || !first_arg_alias_info.value().isWrite()) {
return false;
}
// Check that none of the other args are being aliased
for (auto it = schema.arguments().begin() + 1; it != schema.arguments().end(); ++it) {
const auto& alias_info = it->alias_info();
if (alias_info) {
return false;
}
}
// Check that the first tensor is being returned (i.e., output has a (a!))
const auto& return_alias_info = schema.returns()[0].alias_info();
return return_alias_info && return_alias_info.value().isWrite();
}
static void warnFallback(const c10::FunctionSchema& schema, bool is_inplace) {
if (!globalContext().areVmapFallbackWarningsEnabled()) {
return;
}
auto uses_stack = is_inplace ? "" : " and stack";
TORCH_WARN("Batching rule not implemented for ", schema.operator_name(), " falling back "
"to slow (for loop", uses_stack, ") implementation");
}
// The general flow of the algorithm is as follows.
// - First, we figure out which arguments are BatchedTensors and save them
// to a vector. We also store a vector of which index of the arguments list
// each BatchedTensor appears in. This will be useful for bookkeeping later.
// - Next, we apply the MultiBatchVmapTransform to all of the BatchedTensors.
// This returns a vector of VmapPhysicalView that hold tensors that contain
// all of the collective batch dimensions at the front of the tensors.
// - Then, we attempt to call `op` once per slice of the inputs. To do this,
// we repeatedly slice the input arguments (if they are BatchedTensors),
// put the sliced (or unsliced) version of the input onto the stack, invoke
// the operator, and then pop the results off the stack.
void batchedTensorInplaceForLoopFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
const auto& schema = op.schema();
const auto num_returns = schema.returns().size();
warnFallback(schema, /*in_place*/true);
const auto num_arguments = schema.arguments().size();
const auto arguments = torch::jit::last(stack, num_arguments);
const auto arguments_begin = stack->size() - num_arguments;
// `self` is the Tensor being modified in-place
Tensor self = arguments[0].toTensor();
const auto* self_impl = maybeGetBatchedImpl(self);
std::bitset<kVmapMaxTensorDims> self_vmap_levels;
if (self_impl) {
self_vmap_levels = createVmapLevelsBitset(self_impl->bdims());
}
// Figure out which arguments are BatchedTensor. Save them to a vector.
// For each BatchedTensor, also record what position of `arguments` they came from.
at::SmallVector<Tensor,kVmapTransformStaticInputSize> batched_tensor_inputs;
VmapDimVector batched_tensor_inputs_position;
for (int64_t idx = 0; idx < arguments.size(); ++idx) {
const auto& ivalue = arguments[idx];
if (!ivalue.isTensor()) {
continue;
}
const auto& tensor = ivalue.toTensor();
if (!tensor.defined()) {
continue;
}
const auto* batched = maybeGetBatchedImpl(tensor);
if (!batched) {
continue;
}
// NOTE: [vmap-incompatible in-place operations]
// In-place operations on `self` are not possible if there exists some vmap
// level `l` such that `self` is not being vmapped on that level but another
// argument is. For example, let B0 be a batch dim inside vmap and consider
// vmap(Tensor.add_, in_dims=(None, 0))(torch.ones(3), torch.ones(B0, 3))
// - self is torch.ones(3) and does not participate in this vmap
// - other is BatchedTensor(torch.ones(B0, 3))
// There's no way to do self.add_(other) because `other` has more elements
// than `self` due to being vmapped over.
//
// In the vmap fallback, we should error out when we detect this.
auto other_vmap_levels = createVmapLevelsBitset(batched->bdims());
if (self_vmap_levels != (self_vmap_levels | other_vmap_levels)) {
// Find one vmap level to complain about
auto additional_bdims = (self_vmap_levels | other_vmap_levels) ^ self_vmap_levels;
auto offending_level = llvm::findLastSet(additional_bdims.to_ulong());
// The following prints out "vmap: aten::add_(tensor, ...) is not possible",
// but it would be better to print out "tensor.add_(...) is not possible".
// As far as I can tell, there's no official way to get the `add_` name here,
// and no way to tell whether an operator has method or function variants.
TORCH_CHECK(false,
"vmap: ", schema.name(), "(self, *extra_args) is not possible because ",
"there exists a Tensor `other` in extra_args that has more elements ",
"than `self`. This happened due to `other` being vmapped over but ",
"`self` not being vmapped over at level ", offending_level, ". ",
"Please try to use out-of-place operators instead of ", schema.name(), ". ",
"If said operator is being called inside the PyTorch framework, ",
"please file a bug report instead.");
}
batched_tensor_inputs.push_back(tensor);
batched_tensor_inputs_position.push_back(idx);
}
TORCH_INTERNAL_ASSERT(batched_tensor_inputs.size() > 0);
// MultiBatchVmapTransform the BatchedTensor arguments. This returns
// VmapPhysicalViews that contain all of the batch dimensions.
const auto input_physical_views = MultiBatchVmapTransform::logicalToPhysical(
batched_tensor_inputs);
// Compute the total number of batches
auto num_batch_dims = input_physical_views.front().numBatchDims();
auto first_physical_view_sizes = input_physical_views.front().tensor().sizes();
auto batch_sizes = ArrayRef<int64_t>(
first_physical_view_sizes.begin(), first_physical_view_sizes.begin() + num_batch_dims);
const auto num_batches = c10::multiply_integers(batch_sizes);
// Without a shape-checking API, we're unable to compute the correct shape of
// the output so we just error out.
TORCH_CHECK(num_batches > 0,
"Batching rule not implemented for ", schema.operator_name(), ". ",
"The fallback path does not support vmap over dims of size 0.");
// Strategy: For each batch, we are going to push slices (where applicable)
// of the arguments onto `stack`, and call `op`.
for (int64_t linear_idx = 0; linear_idx < num_batches; ++linear_idx) {
auto index = computeIndex(linear_idx, batch_sizes);
auto batched_tensor_inputs_pos_iter = batched_tensor_inputs_position.begin();
auto input_physical_views_iter = input_physical_views.begin();
for (int64_t arg_idx = 0; arg_idx < num_arguments; ++arg_idx) {
// We assume that torch::jit::Stack is backed by vector<IValue> for
// simplicity. When that is not the case, this code should be updated.
const auto& argument = (*stack)[arguments_begin + arg_idx];
if (batched_tensor_inputs_pos_iter == batched_tensor_inputs_position.end()
|| arg_idx != *batched_tensor_inputs_pos_iter) {
// argument isn't a BatchedTensor
torch::jit::push(stack, argument);
continue;
}
// argument is a BatchedTensor
TORCH_INTERNAL_ASSERT(input_physical_views_iter != input_physical_views.end());
const auto& physical_view_for_argument = *input_physical_views_iter;
auto thing = physical_view_for_argument.tensor().index(index);
torch::jit::push(stack, thing);
batched_tensor_inputs_pos_iter++;
input_physical_views_iter++;
}
op.callBoxed(stack);
torch::jit::drop(stack, 1);
}
// Return the tensor that was written to in-place
torch::jit::drop(stack, num_arguments);
torch::jit::push(stack, self);
}
static Tensor safeStack(TensorList tensors) {
auto is_defined = [](const Tensor& t) { return t.defined(); };
if (std::all_of(tensors.begin(), tensors.end(), is_defined)) {
return at::stack(tensors);
}
// NOTE [vmap through backward and undefined grad]
// While vmapping through backward functions (to compute batched grad), it
// is possible for the backward function to return an undefined grad for some
// grad_input for each example. In that case, we return an undefined grad.
//
// It is theoretically possible for *some* of the examples to produce an
// undefined grad (a kernel could peek at the gradient values and return an
// undefined tensor if it determines the gradient is full of zeros). We
// could handle this by treating the undefined grad as a zero-filled tensor
// of the correct shape while stacking the tensors together. However I expect
// this to happen very rarely (I have not been able to find an example in our
// codebase) so we just error out in this case.
if (std::none_of(tensors.begin(), tensors.end(), is_defined)) {
return Tensor();
}
TORCH_CHECK(false,
"vmap: slow fallback received a mix of undefined and defined tensors ",
"as the result of an operation. This is not supported, please file us ",
"an issue on github.");
}
// TODO: dedup
static bool participatesInCurrentLevel(const Tensor& self) {
auto maybe_level = maybeCurrentDynamicLayer();
TORCH_INTERNAL_ASSERT(maybe_level.has_value());
auto current_level = maybe_level->layerId();
auto* maybe_batched_impl = maybeGetBatchedImpl(self);
if (!maybe_batched_impl) {
return false;
}
const auto& bdims = maybe_batched_impl->bdims();
TORCH_INTERNAL_ASSERT(bdims.size() == 1);
auto self_level = bdims.back().level();
TORCH_INTERNAL_ASSERT(self_level <= current_level);
return self_level == current_level;
}
static bool ivalueParticipatesInCurrentLevel(const IValue& ivalue) {
if (!ivalue.isTensor()) {
return false;
}
return participatesInCurrentLevel(ivalue.toTensor());
}
// The general flow of the algorithm is as follows.
// - First, we figure out which arguments are BatchedTensors and save them
// to a vector. We also store a vector of which index of the arguments list
// each BatchedTensor appears in. This will be useful for bookkeeping later.
// - Next, we apply the MultiBatchVmapTransform to all of the BatchedTensors.
// This returns a vector of VmapPhysicalView that hold tensors that contain
// all of the collective batch dimensions at the front of the tensors.
// - Then, we attempt to call `op` once per slice of the inputs. To do this,
// we repeatedly slice the input arguments (if they are BatchedTensors),
// put the sliced (or unsliced) version of the input onto the stack, invoke
// the operator, and then pop the results off the stack.
// - Each result obtained from the previous step is a slice of the total result,
// so we stack those tensors together to form the final result.
void batchedTensorForLoopFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
const auto& schema = op.schema();
const auto num_returns = schema.returns().size();
const auto num_arguments = schema.arguments().size();
const auto arguments = torch::jit::last(stack, num_arguments);
TORCH_CHECK(areAllReturnsTensors(schema) && !areAnyArgumentsTensorList(schema),
"Batching rule not implemented for ", schema.operator_name(), ". ",
"We could not generate a fallback.");
if (std::none_of(arguments.begin(), arguments.end(), ivalueParticipatesInCurrentLevel)) {
c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
op.callBoxed(stack);
return;
}
if (isInplaceOp(schema)) {
batchedTensorInplaceForLoopFallback(op, stack);
return;
}
TORCH_CHECK(!schema.is_mutable() && !schema.hasAnyAliasInfo(),
"Batching rule not implemented for ", schema.operator_name(), "; ",
"the fallback path doesn't work on out= or view ops.");
TORCH_CHECK(num_returns >= 1,
"Batching rule not implemented for ", schema.operator_name(), ". ",
"The fallback path does not support operations with no returns.");
warnFallback(schema, /*in_place*/false);
const auto arguments_begin = stack->size() - num_arguments;
// Figure out which arguments are BatchedTensor. Save them to a vector.
// For each BatchedTensor, also record what position of `arguments` they came from.
at::SmallVector<Tensor,kVmapTransformStaticInputSize> batched_tensor_inputs;
VmapDimVector batched_tensor_inputs_position;
for (int64_t idx = 0; idx < arguments.size(); ++idx) {
const auto& ivalue = arguments[idx];
if (!ivalue.isTensor()) {
continue;
}
const auto& tensor = ivalue.toTensor();
if (!tensor.defined()) {
continue;
}
const auto* batched = maybeGetBatchedImpl(tensor);
if (!batched) {
continue;
}
batched_tensor_inputs.push_back(tensor);
batched_tensor_inputs_position.push_back(idx);
}
TORCH_INTERNAL_ASSERT(batched_tensor_inputs.size() > 0);
// MultiBatchVmapTransform the BatchedTensor arguments. This returns
// VmapPhysicalViews that contain all of the batch dimensions.
const auto input_physical_views = MultiBatchVmapTransform::logicalToPhysical(
batched_tensor_inputs);
// Compute the total number of batches
auto num_batch_dims = input_physical_views.front().numBatchDims();
auto some_sizes = input_physical_views.front().tensor().sizes();
auto batch_sizes = ArrayRef<int64_t>(some_sizes.begin(), some_sizes.begin() + num_batch_dims);
const auto num_batches = c10::multiply_integers(batch_sizes);
// Without a shape-checking API, we're unable to compute the correct shape of
// the output so we just error out.
TORCH_CHECK(num_batches > 0,
"Batching rule not implemented for ", schema.operator_name(), ". ",
"The fallback path does not support vmap over dims of size 0.");
// Strategy: For each batch, we are going to push slices (where applicable)
// of the arguments onto `stack`, call `op`, and store the result in
// `output_shards`.
//
// NOTE: [Output shards layout]
// Assume that the operator has three outputs: a, b, c.
// The layout of output_shards is as follows:
// [ a0, a1, a2, a3, b0, b1, b2, b3, c0, c1, c2, c3]
// This is so that we can call at::stack([a0...a3]), at::stack([b0...b3])
// more easily in the next step.
std::vector<Tensor> output_shards(num_batches * num_returns);
for (int64_t linear_idx = 0; linear_idx < num_batches; ++linear_idx) {
auto index = computeIndex(linear_idx, batch_sizes);
auto batched_tensor_inputs_pos_iter = batched_tensor_inputs_position.begin();
auto input_physical_views_iter = input_physical_views.begin();
for (int64_t arg_idx = 0; arg_idx < num_arguments; ++arg_idx) {
// We assume that torch::jit::Stack is backed by vector<IValue> for
// simplicity. When that is not the case, this code should be updated.
const auto& argument = (*stack)[arguments_begin + arg_idx];
if (batched_tensor_inputs_pos_iter == batched_tensor_inputs_position.end()
|| arg_idx != *batched_tensor_inputs_pos_iter) {
// argument isn't a BatchedTensor
torch::jit::push(stack, argument);
continue;
}
// argument is a BatchedTensor
TORCH_INTERNAL_ASSERT(input_physical_views_iter != input_physical_views.end());
const auto& physical_view_for_argument = *input_physical_views_iter;
c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
torch::jit::push(stack, physical_view_for_argument.tensor().index(index));
batched_tensor_inputs_pos_iter++;
input_physical_views_iter++;
}
// std::cout << "[Fallback]: ";
// at::dump_tensor((*stack)[stack->size() - 1].toTensor());
c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
op.callBoxed(stack);
// Store the result into `output_shards`. See NOTE: [Output shards layout]
// to learn about the details of how we store the shards.
const auto returns = torch::jit::last(stack, num_returns);
for (int64_t return_idx = 0; return_idx < returns.size(); ++return_idx) {
output_shards[num_batches * return_idx + linear_idx] = returns[return_idx].toTensor();
}
torch::jit::drop(stack, num_returns);
}
// For each output Tensor, stack the shards of the tensor together to form a return
torch::jit::drop(stack, num_arguments);
auto output_shards_chunks = MatrixRef<Tensor>(output_shards, num_batches);
for (int64_t return_idx = 0; return_idx < num_returns; ++return_idx) {
auto shards = output_shards_chunks[return_idx];
c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
auto flat_output = safeStack(shards);
// See NOTE [vmap through backward and undefined grad]
if (!flat_output.defined()) {
torch::jit::push(stack, flat_output);
continue;
}
VmapDimVector output_sizes(batch_sizes);
output_sizes.insert(
output_sizes.end(),
flat_output.sizes().begin() + 1,
flat_output.sizes().end());
torch::jit::push(
stack,
input_physical_views.front().getPhysicalToLogicalMap().apply(flat_output.view(output_sizes)));
}
}
}
} // namespace at

View File

@ -0,0 +1,26 @@
#pragma once
#include <ATen/ATen.h>
#include <ATen/core/op_registration/op_registration.h>
#include <torch/library.h>
namespace at {
namespace functorch {
// If an operator doesn't have a batching rule implemented then we fall back
// to this implementation. The fallback only works on out-of-place operators
// that return only tensors with new memory (e.g., no in-place operators, no
// view operations).
//
// The fallback effectively takes all of the BatchedTensors in `stack`, slices
// them, and runs `op` on all of the corresponding slices to produce slices
// of the outputs. The output slices then get `torch.stack`ed to create the
// final returns.
//
// The performance of the fallback is not very good because it introduces an
// extra copy from stacking the sliced outputs. Because of this, we prefer to
// write batching rules for operators whenever possible.
void batchedTensorForLoopFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack);
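// A minimal registration sketch (illustrative only; the actual registration
// site and dispatch key spelling may differ):
//   TORCH_LIBRARY_IMPL(_, BatchedOutOfTree, m) {
//     m.fallback(torch::CppFunction::makeFromBoxedFunction<&batchedTensorForLoopFallback>());
//   }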
}
} // namespace at

View File

@ -0,0 +1,169 @@
#include <functorch/csrc/BatchedTensorImpl.h>
#include <ATen/WrapDimUtils.h>
#include <c10/util/Exception.h>
#include <functorch/csrc/Constants.h>
namespace at {
namespace functorch {
BatchedTensorImpl::BatchedTensorImpl(Tensor value, BatchDims bdims)
: TensorImpl(
c10::DispatchKeySet(kBatchedKey),
value.dtype(),
value.device()
)
, value_(std::move(value))
, bdims_(std::move(bdims))
{
TORCH_INTERNAL_ASSERT(value_.defined());
set_storage_access_should_throw();
checkInvariants();
const auto public_dims = value_.dim() - bdims_.size();
const auto value_sizes = value_.sizes();
const auto value_strides = value_.strides();
sizes_and_strides_.resize(public_dims);
for (int64_t dim = 0; dim < public_dims; dim++) {
auto actual_dim = actualDim(dim, /*wrap_dim=*/false);
sizes_and_strides_.size_at_unchecked(dim) = value_sizes.at(actual_dim);
sizes_and_strides_.stride_at_unchecked(dim) = value_strides.at(actual_dim);
}
refresh_numel();
refresh_contiguous();
}
BatchedTensorImpl::BatchedTensorImpl(DispatchKeySet key_set, Tensor value, BatchDims bdims)
: TensorImpl(
key_set.add(kBatchedKey),
value.dtype(),
value.device()
)
, value_(std::move(value))
, bdims_(std::move(bdims))
{
TORCH_INTERNAL_ASSERT(value_.defined());
checkInvariants();
TORCH_INTERNAL_ASSERT(bdims_.size() == 1);
refreshSizesAndStrides();
}
void BatchedTensorImpl::refreshSizesAndStrides() {
const auto public_dims = value_.dim() - bdims_.size();
const auto value_sizes = value_.sizes();
const auto value_strides = value_.strides();
sizes_and_strides_.resize(public_dims);
for (int64_t dim = 0; dim < public_dims; dim++) {
auto actual_dim = actualDim(dim, /*wrap_dim=*/false);
sizes_and_strides_.size_at_unchecked(dim) = value_sizes.at(actual_dim);
sizes_and_strides_.stride_at_unchecked(dim) = value_strides.at(actual_dim);
}
refresh_numel();
refresh_contiguous();
}
int64_t BatchedTensorImpl::actualDim(int64_t dim, bool wrap_dim) const {
if (wrap_dim) {
const auto ndim = sizes_and_strides_.size();
dim = maybe_wrap_dim(dim, ndim);
}
auto is_bdim = createBatchDimBitset(bdims_);
// Example: assume dim = 3, and is_bdim = 10010011000...
// The 1's are batch dims and 0's are normal dims of the underlying value_ Tensor.
// actualDim gives us the index of `dim` in the `value_` Tensor, which is equivalent
// to asking "where does the 3rd (0-indexed) zero occur in the bitset?".
// The answer to that is index 5.
//
// TODO(rzou): the PDEP instruction does exactly this
// (https://stackoverflow.com/questions/7669057/find-nth-set-bit-in-an-int)
// but it might require newer (>= ~2015) CPUs. We should clean this up
// if/when we have dropped support for older CPUs.
int64_t non_bdim_count = 0;
for (int64_t actual_dim = 0; actual_dim < kVmapMaxTensorDims; actual_dim++) {
if (is_bdim[actual_dim]) {
continue;
}
if (non_bdim_count == dim) {
return actual_dim;
}
non_bdim_count++;
}
// If we hit this assert, then that means
// `non_bdim_count` + #num_bdims > kVmapMaxTensorDims. We restrict the number
// of dims a BatchedTensorImpl can have to kVmapMaxTensorDims so this should
// never be hit.
TORCH_INTERNAL_ASSERT(false);
}
void BatchedTensorImpl::checkInvariants() const {
int64_t prev_level = -1;
for (const auto& bdim : bdims_) {
TORCH_INTERNAL_ASSERT(bdim.level() > prev_level);
prev_level = bdim.level();
}
}
// The following are publicly exposed as methods of Tensor
bool BatchedTensorImpl::is_contiguous(at::MemoryFormat memory_format) const {
TORCH_CHECK(memory_format == MemoryFormat::Contiguous,
"NYI: querying is_contiguous inside of vmap for memory_format ",
"other than torch.contiguous_format");
return is_contiguous_;
}
// The following are some internal inherited methods that we do not support.
// They should never get called.
void BatchedTensorImpl::set_size(int64_t dim, int64_t new_size) {
TORCH_INTERNAL_ASSERT(false, "Can't set_size for BatchedTensorImpl");
}
void BatchedTensorImpl::set_stride(int64_t dim, int64_t new_stride) {
TORCH_INTERNAL_ASSERT(false, "Can't set_stride for BatchedTensorImpl");
}
void BatchedTensorImpl::set_storage_offset(int64_t storage_offset) {
TORCH_INTERNAL_ASSERT(false, "Can't set_storage_offset for BatchedTensorImpl");
}
#ifdef DEBUG
bool BatchedTensorImpl::has_storage() const {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!storage_, "BatchedTensorImpl assumes that storage_ is never set");
return false;
}
#endif
const char* BatchedTensorImpl::tensorimpl_type_name() const {
return "BatchedTensorImpl";
}
Tensor makeBatched(const Tensor& tensor, BatchDims bdims) {
DispatchKeySet key_set;
if (tensor.is_cuda()) {
key_set = key_set.add(DispatchKey::CUDA);
}
return at::detail::make_tensor<BatchedTensorImpl>(key_set, tensor, std::move(bdims));
}
Tensor addBatchDim(const Tensor& tensor, int64_t level, int64_t dim) {
BatchDims new_bdims = { { level, dim } };
TORCH_INTERNAL_ASSERT(new_bdims.size() == 1);
return makeBatched(tensor, std::move(new_bdims));
}
bool inplaceIsVmapCompatible(const Tensor& self, const Tensor& other) {
const auto* other_batched = maybeGetBatchedImpl(other);
if (!other_batched) {
return true;
}
const auto* self_batched = maybeGetBatchedImpl(self);
if (!self_batched) {
// self is not batched but other is batched
return false;
}
auto self_levels = createVmapLevelsBitset(self_batched->bdims());
auto other_levels = createVmapLevelsBitset(other_batched->bdims());
return self_levels == (self_levels | other_levels);
}
}
} // namespace at

View File

@ -0,0 +1,158 @@
#pragma once
#include <bitset>
#include <ATen/ArrayRef.h>
#include <ATen/SmallVector.h>
#include <ATen/Tensor.h>
#include <functorch/csrc/Constants.h>
namespace at {
namespace functorch {
using Tensor = at::Tensor;
// We assume this in a few other places in the codebase,
// but there isn't a centralized definition.
constexpr int64_t kVmapMaxTensorDims = 64;
// The valid vmap levels range from [0, 64). This effectively means that we
// support a maximum of 64 nested vmaps.
constexpr int64_t kVmapNumLevels = 64;
// Store this number of elements of BatchDims on the stack. Most people will
// probably use <= 5 nested vmaps, but adjust this number as necessary.
constexpr int64_t kBatchDimsStackSize = 5;
// A BatchDim represents a "private" dimension on a Tensor created inside of
// vmap. It is a (level, dim) tuple, with the `dim` indicating which dimension
// is being vmap'ed over and the `level` being an identifier for which vmap
// said dimension was created inside. The `dim` corresponds to a "physical
// dim" - it is a dimension index on the underlying physical tensor that is being
// vmapped over.
struct BatchDim {
BatchDim(int64_t level, int64_t dim) : dim_(dim), level_(level) {}
int64_t dim() const {
return dim_;
}
int64_t level() const {
return level_;
}
private:
int64_t dim_;
int64_t level_;
};
using BatchDims = at::SmallVector<BatchDim, kBatchDimsStackSize>;
using BatchDimsRef = at::ArrayRef<BatchDim>;
// A BatchedTensorImpl holds an underlying Tensor and a list of BatchDims.
// NB: We use the term "BatchedTensor" to mean a Tensor that is backed with a
// BatchedTensorImpl.
//
// The batch dimensions are treated as being "private"; they are not user-visible.
// For example, in the following Tensor,
// bt = BatchedTensorImpl(ones(2, 3, 5, 7), [(lvl=1, dim=0), (lvl=2, dim=1)])
// dimensions 0 and 1 are batch dimensions.
//
// bt.sizes() returns (5, 7); bt.sum(0) performs a reduction over the (public)
// dim 0, which is equivalent to dim 2 in the underlying ones(2, 3, 5, 7) tensor.
struct BatchedTensorImpl : public c10::TensorImpl {
explicit BatchedTensorImpl(Tensor value, BatchDims bdims);
explicit BatchedTensorImpl(at::DispatchKeySet key_set, Tensor value, BatchDims bdims);
// Returns a reference to BatchDims that represent which dimensions of this
// tensor are private.
BatchDimsRef bdims() const { return bdims_; }
// BatchedTensorImpl wraps a Tensor
const Tensor& value() const { return value_; };
// Given a public dimension index, return the dimension index in the underlying
// value() tensor.
// For example, if we have
// bt = BatchedTensorImpl(ones(2, 3, 5, 7), [(lvl=1, dim=0), (lvl=2, dim=2)])
// bt.actualDim(0) -> 1
// bt.actualDim(1) -> 3
// bt.actualDim(2) -> Error
int64_t actualDim(int64_t dim, bool wrap_dim = true) const;
// Override a bunch of methods inherited from TensorImpl to return error messages.
bool is_contiguous(at::MemoryFormat memory_format=at::MemoryFormat::Contiguous) const override;
void set_size(int64_t dim, int64_t new_size) override;
void set_stride(int64_t dim, int64_t new_stride) override;
void set_storage_offset(int64_t storage_offset) override;
#ifdef DEBUG
bool has_storage() const override;
#endif
void refreshSizesAndStrides();
private:
// see NOTE: [BatchedTensorImpl levels invariant]
void checkInvariants() const;
const char* tensorimpl_type_name() const override;
Tensor value_;
// Note: [BatchedTensorImpl levels invariant]
// There is an invariant that the BatchDims must be stored in increasing `level`
// order. That is, for i < j, bdims_[i].level must be less than bdims_[j].level.
BatchDims bdims_;
};
// NB: We use the term "BatchedTensor" to mean a Tensor that is backed with a
// BatchedTensorImpl.
inline bool isBatchedTensor(const Tensor& tensor) {
return tensor.unsafeGetTensorImpl()->key_set().has(kBatchedKey);
}
// It is unsafe to call this on a Tensor that is not backed by a
// BatchedTensorImpl. Please use `maybeGetBatchedImpl` whenever possible.
inline BatchedTensorImpl* unsafeGetBatchedImpl(Tensor tensor) {
return static_cast<BatchedTensorImpl*>(tensor.unsafeGetTensorImpl());
}
inline BatchedTensorImpl* maybeGetBatchedImpl(Tensor tensor) {
if (!isBatchedTensor(tensor)) {
return nullptr;
}
return unsafeGetBatchedImpl(tensor);
}
// Returns a bitset. If bit i is set, then that means dim i is a batchdim.
inline std::bitset<kVmapMaxTensorDims> createBatchDimBitset(BatchDimsRef bdims) {
std::bitset<kVmapMaxTensorDims> is_bdim;
for (const auto& bdim : bdims) {
is_bdim.set(bdim.dim());
}
return is_bdim;
}
// Creates a bitset for all of the levels present in `bdims`
inline std::bitset<kVmapNumLevels> createVmapLevelsBitset(BatchDimsRef bdims) {
std::bitset<kVmapNumLevels> result;
for (const auto& bdim : bdims) {
result.set(bdim.level());
}
return result;
}
inline std::ostream& operator<<(std::ostream& out, const BatchDim& bdim) {
out << "(lvl=" << bdim.level() << ", dim=" << bdim.dim() << ")";
return out;
}
// Use this to construct a BatchedTensor from a regular Tensor
TORCH_API Tensor makeBatched(const Tensor& tensor, BatchDims bdims);
// Adds a batch dim to `tensor`, returning a BatchedTensor
TORCH_API Tensor addBatchDim(const Tensor& tensor, int64_t level, int64_t dim);
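// For example (illustrative): addBatchDim(at::ones({2, 3, 5}), /*level=*/1, /*dim=*/0)
// returns a BatchedTensor whose public sizes() are (3, 5).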
// Checks if an inplace operation on self and other is "vmap compatible".
// See NOTE: [vmap-incompatible in-place operations] for the definition of this.
TORCH_API bool inplaceIsVmapCompatible(const Tensor& self, const Tensor& other);
}
} // namespace at

View File

@ -0,0 +1,155 @@
#pragma once
#include <ATen/Tensor.h>
namespace at {
namespace functorch {
// Metaprogramming things
template <class... Items> using typelist = c10::guts::typelist::typelist<Items...>;
template <class TypeList> using head_t = c10::guts::typelist::head_t<TypeList>;
template <class TL1, class TL2> using concat_t = c10::guts::typelist::concat_t<TL1, TL2>;
template <typename T> class debug_t;
// tail operation
template<class TypeList>
struct tail final {
static_assert(c10::guts::false_t<TypeList>::value,
"In typelist::tail<T>, the T argument must be typelist<...>.");
};
template<class Head, class... Tail>
struct tail<typelist<Head, Tail...>> final {
using type = typelist<Tail...>;
};
template<class TypeList> using tail_t = typename tail<TypeList>::type;
template <class First, class Second, class Next, class Tail>
struct IfFirstIsTensorAndSecondisBatchDimThenTailElseNext {
using type = Next;
};
template <class Next, class Tail>
struct IfFirstIsTensorAndSecondisBatchDimThenTailElseNext<Tensor, optional<int64_t>, Next, Tail> {
using type = Tail;
};
template <class Next, class Tail>
struct IfFirstIsTensorAndSecondisBatchDimThenTailElseNext<const Tensor&, optional<int64_t>, Next, Tail> {
using type = Tail;
};
template <class Next, class Tail>
struct IfFirstIsTensorAndSecondisBatchDimThenTailElseNext<Tensor&, optional<int64_t>, Next, Tail> {
using type = Tail;
};
template <class TypeList> struct RemoveBatchDimAfterTensor {
using first = head_t<TypeList>;
using next = tail_t<TypeList>;
using second = head_t<next>;
using tail = tail_t<next>;
using type = concat_t<
typelist<first>,
typename RemoveBatchDimAfterTensor<
typename IfFirstIsTensorAndSecondisBatchDimThenTailElseNext<first, second, next, tail>::type
>::type
>;
};
template <class Type> struct RemoveBatchDimAfterTensor<typelist<Type>> {
using type = typelist<Type>;
};
template <> struct RemoveBatchDimAfterTensor<typelist<>> {
using type = typelist<>;
};
template<class TypeList> using remove_batch_dim_after_tensor_t = typename RemoveBatchDimAfterTensor<TypeList>::type;
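// Illustrative example (a sketch, assuming the aliases above): this transform maps
//   typelist<const Tensor&, optional<int64_t>, int64_t>
// to
//   typelist<const Tensor&, int64_t>
// i.e. it drops the optional<int64_t> batch-dim argument that follows a Tensor.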
// TODO: get rid of these
// Do I need templates on templates now?
// template <typename func_t> struct LowerToNextLayer {};
// template <typename Return, typename... Args> struct LowerToNextLayer<Return(Args...)> {
// // How to pass in batch_rule directly?
// static Return apply(Args... args);
// };
template <typename batch_rule_t, typename Result, typename... Args>
Result lowerToNextLayer(batch_rule_t batch_rule, Args... args);
//# Tensor lowerToNextLayer(
//# std::function<std::tuple<Tensor,optional<int64_t>>(const Tensor&, optional<int64_t>)> batch_rule,
//# const Tensor& tensor);
std::tuple<Tensor,optional<int64_t>> abs_batch_rule(const Tensor& tensor, optional<int64_t> batch_dim);
template<typename F, F Func, typename Return, typename TupleArgs> struct TORCH_API Dummy {};
template<typename F, F Func, typename Return, typename...T> struct Dummy<F, Func, Return, std::tuple<T...>> {
static Return apply(T... args) {
return lowerToNextLayer(abs_batch_rule, std::forward<T>(args)...);
}
};
template <typename T> struct UnpackSingleItemTuple {
using type = T;
};
template <typename T> struct UnpackSingleItemTuple<std::tuple<T>> {
using type = T;
};
template <typename T> using unpack_single_item_tuple_t = typename UnpackSingleItemTuple<T>::type;
template <typename Return, typename TupleArgs> struct BuildFunctionHelper;
template <typename Return, typename... Args> struct BuildFunctionHelper<Return, std::tuple<Args...>> {
using type = Return(Args...);
};
template <typename Return, typename TL>
struct BuildFunction {
using type = typename BuildFunctionHelper<Return, c10::guts::typelist::to_tuple_t<TL>>::type;
};
template <typename Return, typename TL> using build_function_t = typename BuildFunction<Return, TL>::type;
// std::tuple<Tensor,optional<int64_t>> (*kAbsBatchRule)(const Tensor& Tensor, optional<int64_t>)
// = &abs_batch_rule;
template <typename batch_rule_t> struct ToOperatorType {
using batch_rule_return_type = typename c10::guts::function_traits<batch_rule_t>::return_type;
using batch_rule_parameter_types = typename c10::guts::function_traits<batch_rule_t>::parameter_types;
using operator_parameter_types = remove_batch_dim_after_tensor_t<batch_rule_parameter_types>;
using operator_return_type =
unpack_single_item_tuple_t<
c10::guts::typelist::to_tuple_t<
remove_batch_dim_after_tensor_t<
c10::guts::typelist::from_tuple_t<batch_rule_return_type>>>>;
using type = build_function_t<operator_return_type, operator_parameter_types>;
};
template <typename batch_rule_t> using to_operator_t = typename ToOperatorType<batch_rule_t>::type;
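// Illustrative example (a sketch): for a batch rule with the signature of
// abs_batch_rule above, i.e.
//   std::tuple<Tensor, optional<int64_t>> (const Tensor&, optional<int64_t>)
// to_operator_t yields the corresponding operator type
//   Tensor (const Tensor&)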
template <typename F, F Func> struct TORCH_API PrimBatchRule3 {
using func_t = to_operator_t<typename std::remove_pointer<F>::type>;
using result_type = typename c10::guts::function_traits<func_t>::return_type;
using parameter_types = c10::guts::typelist::to_tuple_t<typename c10::guts::function_traits<func_t>::parameter_types>;
static auto apply = Dummy<F, Func, result_type, parameter_types>::apply;
};
template<typename Return, typename TypeList> struct TORCH_API PrimBatchRule5 {};
template<typename Return, typename... T> struct PrimBatchRule5<Return, typelist<T...>> {
static inline Return apply(T... args) {
return lowerToNextLayer(abs_batch_rule, std::forward<T>(args)...);
}
};
template<typename func_t> struct PrimBatchRule6 {};
template<typename Return, typename... Args> struct PrimBatchRule6<Return (Args...)> {
static inline Return apply(Args... args) {
return lowerToNextLayer(abs_batch_rule, std::forward<Args>(args)...);
}
};
// template<typename batch_rule_t, batch_rule_t BatchRule> struct PrimBatchRule7 {};
// template<typename batch_rule_t, batch_rule_t BatchRule, typename BRReturn, typename... BRArgs>
// struct PrimBatchRule7<BRReturn(*)(BRArgs...), BatchRule> {
template<typename br_t, br_t BatchRule, typename func_t> struct PrimBatchRule7 {};
template<typename br_t, br_t BatchRule, typename Return, typename... Args> struct PrimBatchRule7<
br_t, BatchRule, Return (Args...)> {
static inline Return apply(Args... args) {
return lowerToNextLayer<br_t, Return, Args...>(BatchRule, std::forward<Args>(args)...);
}
};
}
} // namespace at

File diff suppressed because it is too large

View File

@ -0,0 +1,9 @@
#pragma once
#include <c10/core/DispatchKey.h>
namespace at {
namespace functorch {
constexpr auto kBatchedKey = c10::DispatchKey::BatchedOutOfTree;
}} // namespace at::functorch

View File

@ -0,0 +1,377 @@
#include <functorch/csrc/DynamicLayer.h>
#include <functorch/csrc/TensorWrapper.h>
#include <torch/library.h>
#include <c10/core/impl/LocalDispatchKeySet.h>
#include <ATen/core/dispatch/Dispatcher.h>
#include <torch/csrc/autograd/variable.h>
#include <c10/util/ThreadLocalDebugInfo.h>
namespace at {
namespace functorch {
// Initial autograd layer, because autograd is always "on"
// thread_local std::vector<DynamicLayer> dynamicLayerStack = { DynamicLayer(DispatchKey::Autograd, 1) };
using DynmetaData = std::unordered_map<int64_t, std::shared_ptr<bool>>;
DynmetaData kDynMetaDataSingleton;
static DynmetaData& getGlobalDynmetaData() {
return kDynMetaDataSingleton;
}
class DynamicLayerStackHolder : public c10::DebugInfoBase {
public:
DynamicLayerStackHolder() {}
virtual ~DynamicLayerStackHolder() {}
std::vector<DynamicLayer> dynamicLayerStack = { DynamicLayer(DispatchKey::Autograd, 1) };
};
thread_local std::shared_ptr<DynamicLayerStackHolder> kDynamicLayerStack;
static std::vector<DynamicLayer>& dynamicLayerStackAccessor() {
if (kDynamicLayerStack == nullptr) {
kDynamicLayerStack = std::make_shared<DynamicLayerStackHolder>();
c10::ThreadLocalDebugInfo::_push(
// TODO: this isn't a PRODUCER_INFO, but there's nothing else we can use
c10::DebugInfoKind::PRODUCER_INFO,
kDynamicLayerStack);
}
TORCH_INTERNAL_ASSERT(kDynamicLayerStack != nullptr);
// TODO: can figure out how to memoize this. std::call_once with thread_local?
return kDynamicLayerStack->dynamicLayerStack;
}
std::shared_ptr<bool> getLifeHandleForLevel(int64_t level) {
auto it = getGlobalDynmetaData().find(level);
TORCH_INTERNAL_ASSERT(it != kDynMetaDataSingleton.end(), "level should be alive");
return it->second;
}
optional<DynamicLayer> maybeCurrentDynamicLayer() {
auto& dynamicLayerStack = dynamicLayerStackAccessor();
// NB: Exception for regular autograd, maybe tweak this
if (dynamicLayerStack.size() <= 1) {
return {};
}
return dynamicLayerStack.back();
}
const std::vector<DynamicLayer>& getDynamicLayerStack() {
return dynamicLayerStackAccessor();
}
void setDynamicLayerStack(const std::vector<DynamicLayer>& stack) {
dynamicLayerStackAccessor() = stack;
}
static DynamicLayer popDynamicLayer() {
auto& dynamicLayerStack = dynamicLayerStackAccessor();
TORCH_INTERNAL_ASSERT(dynamicLayerStack.size() > 0);
auto result = dynamicLayerStack.back();
TORCH_INTERNAL_ASSERT(result.key() != DispatchKey::Undefined);
dynamicLayerStack.pop_back();
if (dynamicLayerStack.size() == 0) {
// std::cout << "DynamicLayer off" << std::endl;
c10::impl::tls_set_dispatch_key_included(DispatchKey::DynamicLayerFront, false);
c10::impl::tls_set_dispatch_key_included(DispatchKey::DynamicLayerBack, false);
}
return result;
}
static int64_t pushDynamicLayer(DispatchKey key) {
auto& dynamicLayerStack = dynamicLayerStackAccessor();
TORCH_INTERNAL_ASSERT(key != DispatchKey::Undefined);
TORCH_INTERNAL_ASSERT(key != DispatchKey::Batched);
auto layerId = 1 + dynamicLayerStack.size();
dynamicLayerStack.emplace_back(key, layerId);
if (layerId == 2) {
// std::cout << "DynamicLayer on" << std::endl;
c10::impl::tls_set_dispatch_key_included(DispatchKey::DynamicLayerFront, true);
c10::impl::tls_set_dispatch_key_included(DispatchKey::DynamicLayerBack, true);
}
return layerId;
}
int64_t initAndPushDynamicLayer(DispatchKey key) {
auto layerId = pushDynamicLayer(key);
auto& data = getGlobalDynmetaData();
TORCH_INTERNAL_ASSERT(data.find(layerId) == data.end());
data[layerId] = std::make_shared<bool>(true);
return layerId;
}
DynamicLayer popDynamicLayerAndDeleteMetadata() {
auto result = popDynamicLayer();
auto level = result.layerId();
// TODO: is this lock safe? No one else should be writing to the same bucket
if (c10::show_dispatch_trace_enabled()) {
std::cout << "deleting metadata" << std::endl;
}
auto& data = getGlobalDynmetaData();
auto it = data.find(level);
if (it == data.end()) {
return result;
}
if (c10::show_dispatch_trace_enabled()) {
std::cout << "deleted metadata for level " << level << std::endl;
}
// invalidate the thing
*(it->second) = false;
data.erase(level);
return result;
}
static Tensor materializeGradWrappers(const Tensor& tensor, const std::vector<DynamicLayer>& dynlayerStack) {
if (!tensor.defined()) {
return tensor;
}
// TODO: First entry in the stack is a default autograd key.
// We should clean up the logic
if (dynlayerStack.size() <= 1) {
return tensor;
}
if (dynlayerStack.back().key() != DispatchKey::Autograd) {
return tensor;
}
auto cur_level = dynlayerStack.back().layerId();
auto* wrapper = maybeGetTensorWrapper(tensor);
if (!wrapper) {
return makeTensorWrapper(tensor, cur_level);
}
TORCH_INTERNAL_ASSERT(wrapper->level().value() <= cur_level, "escaped?");
if (wrapper->level().value() == cur_level) {
TORCH_INTERNAL_ASSERT(tensor.defined());
return tensor;
}
return makeTensorWrapper(tensor, cur_level);
}
static Tensor unwrapIfDead(const Tensor& tensor) {
auto* wrapped = maybeGetTensorWrapper(tensor);
if (!wrapped) {
return tensor;
}
if (wrapped->is_alive()) {
return tensor;
}
return wrapped->value();
}
static void foreachTensorInplace(std::vector<IValue>& args, int64_t begin, int64_t end,
std::function<Tensor(const Tensor&)> func) {
TORCH_INTERNAL_ASSERT(begin >= 0);
TORCH_INTERNAL_ASSERT(end >= 0);
TORCH_INTERNAL_ASSERT(begin <= end);
for (int64_t idx = begin; idx < end; idx++) {
auto ivalue = args[idx];
if (ivalue.isTensorList()) {
auto list = ivalue.toTensorList();
for (int64_t list_idx = 0; list_idx < list.size(); list_idx++) {
list[list_idx] = func(list[list_idx]);
}
args[idx] = list;
}
TORCH_INTERNAL_ASSERT(!ivalue.isGenericDict(), "No operators can accept GenericDict");
if (!ivalue.isTensor()) {
continue;
}
Tensor value = ivalue.toTensor();
Tensor replacement = func(value);
args[idx] = std::move(replacement);
// sanity checks
if (ivalue.toTensor().defined()) {
TORCH_INTERNAL_ASSERT(args[idx].toTensor().defined());
}
}
}
constexpr DispatchKeySet all_dynlayer_keyset = DispatchKeySet({
DispatchKey::DynamicLayerFront,
DispatchKey::DynamicLayerBack,
DispatchKey::TensorWrapper,
// DispatchKey::Batched,
DispatchKey::BatchedOutOfTree,
DispatchKey::InplaceOrView
}) | autograd_dispatch_keyset;
static void sanityCheckStack(torch::jit::Stack* stack) {
if (stack->size() > 0) {
auto last_ivalue = (*stack)[stack->size() - 1];
if (last_ivalue.isTensor()) {
auto tensor = last_ivalue.toTensor();
auto* wrapper = maybeGetTensorWrapper(tensor);
TORCH_INTERNAL_ASSERT(wrapper == nullptr);
TORCH_INTERNAL_ASSERT(tensor.has_storage());
}
}
}
void dynamicLayerFrontFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
auto& dynamicLayerStack = dynamicLayerStackAccessor();
if (c10::show_dispatch_trace_enabled()) {
std::cout << "DLS size: " << dynamicLayerStack.size() << std::endl;
}
if (dynamicLayerStack.size() == 0) {
sanityCheckStack(stack);
c10::impl::ExcludeDispatchKeyGuard guard(all_dynlayer_keyset);
op.callBoxed(stack);
return;
}
// Unwrap dead GradWrappers, materialize live ones
auto maybeTransformGradWrappers = [](const Tensor& tensor) {
auto result = unwrapIfDead(tensor);
return materializeGradWrappers(result, getDynamicLayerStack());
};
auto num_args = op.schema().arguments().size();
foreachTensorInplace(*stack, stack->size() - num_args, stack->size(), maybeTransformGradWrappers);
auto layer = dynamicLayerStack.back();
DispatchKeySet exclude = DispatchKeySet::FULL;
exclude = exclude.remove(DispatchKey::DynamicLayerBack);
if (layer.key() == DispatchKey::Autograd) {
exclude = exclude - autograd_dispatch_keyset;
exclude = exclude.remove(DispatchKey::InplaceOrView);
// } else if (layer.key() == DispatchKey::Batched) {
// exclude = exclude.remove(DispatchKey::Batched);
} else if (layer.key() == DispatchKey::BatchedOutOfTree) {
exclude = exclude.remove(DispatchKey::BatchedOutOfTree);
} else {
TORCH_INTERNAL_ASSERT(false);
}
c10::impl::ExcludeDispatchKeyGuard guard(exclude);
// Re-dispatch
op.callBoxed(stack);
}
struct WithoutTop {
WithoutTop(): layer_(popDynamicLayer()) {
}
~WithoutTop() {
pushDynamicLayer(layer_.key());
}
bool prev_grad_enabled_;
DynamicLayer layer_;
};
void dynamicLayerBackFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
auto cur_level = getDynamicLayerStack().back().layerId();
auto cur_key = getDynamicLayerStack().back().key();
auto unwrap = [&](const Tensor& tensor) {
if (!tensor.defined()) {
return tensor;
}
auto* maybe_tensor_wrapper = maybeGetTensorWrapper(tensor);
if (!maybe_tensor_wrapper) {
return tensor;
}
if (maybe_tensor_wrapper->level().value() == cur_level) {
return maybe_tensor_wrapper->value();
}
if (c10::show_dispatch_trace_enabled()) {
std::cout << "unwrap " << cur_level << std::endl;
}
return tensor;
};
auto wrap = [&](const Tensor& tensor) {
if (!tensor.defined()) {
return tensor;
}
if (cur_level == 1) {
return tensor;
}
if (c10::show_dispatch_trace_enabled()) {
std::cout << "wrap " << cur_level << std::endl;
}
return makeTensorWrapper(tensor, cur_level);
};
// TODO: we only need to do the following (marked with !) on in-place functions
// that modify sizes or strides. There aren't many of them.
// If autograd dispatch key:
// 1. (!) Put a copy of all of the args onto the stack
// 2. Unwrap all the args in the copy set
// 3. Call the operator
// 4. Wrap the output
// 5. (!) refreshSizesAndStrides for all the args in the original set
// 6. (!) Pop those args off.
// Step 1 & 2
if (cur_key == DispatchKey::Autograd) {
auto args_size = op.schema().arguments().size();
// Step 1
auto front = stack->size() - args_size;
for (int64_t arg_idx = 0; arg_idx < args_size; arg_idx++) {
stack->push_back((*stack)[front + arg_idx]);
}
// Step 2
foreachTensorInplace(*stack, stack->size() - args_size, stack->size(), unwrap);
}
// pop the top layer. Put it back on dtor.
WithoutTop guard;
// "reset exclude set"
// TODO: Still a problem with composability and AutoNonVariableTypeGuard.
// Users cannot use torch.no_grad; otherwise there will be problems.
auto keyset = c10::impl::PODLocalDispatchKeySet();
c10::impl::_force_tls_local_dispatch_key_set(keyset);
c10::impl::tls_set_dispatch_key_included(DispatchKey::DynamicLayerFront, true);
c10::impl::tls_set_dispatch_key_included(DispatchKey::DynamicLayerBack, true);
// Re-dispatch
op.callBoxed(stack);
// Step 4, 5, 6
if (cur_key == DispatchKey::Autograd) {
// Step 4
auto ret_size = op.schema().returns().size();
foreachTensorInplace(*stack, stack->size() - ret_size, stack->size(), wrap);
// Step 5
auto args_size = op.schema().arguments().size();
auto args_front = stack->size() - args_size - ret_size;
for (int64_t arg_idx = 0; arg_idx < args_size; arg_idx++) {
auto& ivalue = (*stack)[args_front + arg_idx];
if (!ivalue.isTensor()) {
continue;
}
auto maybe_tensor_wrapper = maybeGetTensorWrapper(ivalue.toTensor());
if (!maybe_tensor_wrapper) {
continue;
}
maybe_tensor_wrapper->refreshSizesAndStrides();
}
// Step 6
stack->erase(stack->end() - (args_size + ret_size), stack->end() - ret_size);
}
}
TORCH_LIBRARY_IMPL(_, DynamicLayerFront, m) {
m.fallback(torch::CppFunction::makeFromBoxedFunction<&dynamicLayerFrontFallback>());
}
TORCH_LIBRARY_IMPL(_, DynamicLayerBack, m) {
m.fallback(torch::CppFunction::makeFromBoxedFunction<&dynamicLayerBackFallback>());
}
// TORCH_LIBRARY_IMPL(aten, DynamicLayerFront, m) {
// m.impl("_unwrap_for_grad", native::_unwrap_for_grad);
// m.impl("dump_tensor", native::dump_tensor);
// m.impl("dlevel", native::dlevel);
// }
}
} // namespace at

View File

@ -0,0 +1,33 @@
#pragma once
#include <c10/core/DispatchKey.h>
#include <c10/util/Optional.h>
#include <unordered_map>
#include <mutex>
// Forward declared bc I am lazy
namespace c10 { struct AutogradMetaInterface; }
namespace at {
namespace functorch {
struct TORCH_API DynamicLayer {
DynamicLayer(DispatchKey key, int64_t layerId): key_(key), layerId_(layerId) {}
DispatchKey key() const { return key_; }
int64_t layerId() const { return layerId_; }
private:
DispatchKey key_;
int64_t layerId_;
};
TORCH_API int64_t initAndPushDynamicLayer(DispatchKey key);
TORCH_API DynamicLayer popDynamicLayerAndDeleteMetadata();
TORCH_API c10::optional<DynamicLayer> maybeCurrentDynamicLayer();
TORCH_API const std::vector<DynamicLayer>& getDynamicLayerStack();
TORCH_API void setDynamicLayerStack(const std::vector<DynamicLayer>& stack);
// NB: not lock safe. TODO: does it need a lock?
TORCH_API std::shared_ptr<bool> getLifeHandleForLevel(int64_t level);
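// The returned handle is shared with TensorWrappers created at `level`; it is
// flipped to false when that level is popped (see
// popDynamicLayerAndDeleteMetadata), which is how wrappers detect that they
// are "dead".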
}
} // namespace at

View File

@ -0,0 +1,232 @@
#include <functorch/csrc/TensorWrapper.h>
#include <functorch/csrc/DynamicLayer.h>
#include <functorch/csrc/BatchedTensorImpl.h>
#include <torch/library.h>
#include <ATen/core/dispatch/Dispatcher.h>
namespace at {
namespace functorch {
void dumpTensor(std::ostream& ss, const Tensor& tensor) {
auto* wrapped = maybeGetTensorWrapper(tensor);
if (!wrapped) {
auto* batched = maybeGetBatchedImpl(tensor);
if (batched) {
ss << "Batched[" << batched->bdims() << ", ";
dumpTensor(ss, batched->value());
ss << "]";
return;
}
ss << "Tensor" << tensor.sizes();
return;
}
ss << "Wrapper[";
if (wrapped->level().has_value()) {
ss << wrapped->level().value() << ", ";
} else {
ss << "dead, ";
}
dumpTensor(ss, wrapped->value());
ss << "]";
}
void TensorWrapper::refreshSizesAndStrides() {
auto dim = value_.dim();
auto sizes = value_.sizes();
auto strides = value_.strides();
sizes_and_strides_.resize(value_.dim());
for (int64_t i = 0; i < dim; i++) {
sizes_and_strides_.size_at_unchecked(i) = sizes[i];
sizes_and_strides_.stride_at_unchecked(i) = strides[i];
}
refresh_numel();
refresh_contiguous();
}
void dumpTensorCout(const Tensor& tensor) {
dumpTensor(std::cout, tensor);
std::cout << std::endl;
}
c10::intrusive_ptr<TensorWrapper> makeTensorWrapperPtr(const Tensor& tensor, int64_t level, bool should_be_alive) {
// TODO: denylist non-cuda/cpu backends to avoid funny business
DispatchKeySet key_set;
if (tensor.is_cuda()) {
key_set = key_set.add(DispatchKey::CUDA);
key_set = key_set.add(DispatchKey::AutogradCUDA);
} else {
key_set = key_set.add(DispatchKey::CPU);
key_set = key_set.add(DispatchKey::AutogradCPU);
}
key_set = key_set.add(DispatchKey::TensorWrapper);
if (should_be_alive) {
return c10::make_intrusive<TensorWrapper>(key_set, tensor, level, getLifeHandleForLevel(level));
} else {
return c10::make_intrusive<TensorWrapper>(key_set, tensor, level, std::make_shared<bool>(false));
}
}
Tensor makeTensorWrapper(const Tensor& tensor, int64_t level) {
auto wrapped = maybeGetTensorWrapper(tensor);
if (wrapped) {
TORCH_INTERNAL_ASSERT(wrapped->level() < level);
}
// TODO: denylist non-cuda/cpu backends to avoid funny business
DispatchKeySet key_set;
if (tensor.is_cuda()) {
key_set = key_set.add(DispatchKey::CUDA);
key_set = key_set.add(DispatchKey::AutogradCUDA);
} else {
key_set = key_set.add(DispatchKey::CPU);
key_set = key_set.add(DispatchKey::AutogradCPU);
}
key_set = key_set.add(DispatchKey::TensorWrapper);
auto life_handle = getLifeHandleForLevel(level);
auto result = at::detail::make_tensor<TensorWrapper>(key_set, tensor, level, std::move(life_handle));
TORCH_INTERNAL_ASSERT(result.key_set().has(DispatchKey::TensorWrapper));
return result;
}
bool TensorWrapper::is_alive() const {
return *is_alive_;
}
c10::intrusive_ptr<TensorImpl> TensorWrapper::shallow_copy_and_detach(
const c10::VariableVersion& version_counter,
bool allow_tensor_metadata_change) const {
auto dest_impl = makeTensorWrapperPtr(value(), level_, is_alive());
dest_impl->set_version_counter(version_counter);
// TODO: is this even right?
dest_impl->set_allow_tensor_metadata_change(allow_tensor_metadata_change);
return dest_impl;
}
c10::intrusive_ptr<TensorImpl> TensorWrapper::shallow_copy_and_detach(
c10::VariableVersion&& version_counter,
bool allow_tensor_metadata_change) const {
auto dest_impl = makeTensorWrapperPtr(value(), level_, is_alive());
dest_impl->set_version_counter(version_counter);
// TODO: is this even right?
dest_impl->set_allow_tensor_metadata_change(allow_tensor_metadata_change);
return dest_impl;
}
void TensorWrapper::shallow_copy_from(const c10::intrusive_ptr<TensorImpl>& impl) {
TORCH_INTERNAL_ASSERT(false, "NYI");
}
TensorWrapper::TensorWrapper(
c10::DispatchKeySet key_set,
Tensor value,
int64_t level,
std::shared_ptr<bool> is_alive,
bool use_value_sizes_strides)
: TensorImpl(key_set, value.dtype(), value.device())
, value_(std::move(value))
, level_(level)
, is_alive_(std::move(is_alive))
{
TORCH_INTERNAL_ASSERT(value_.defined());
set_storage_access_should_throw();
// TODO: need to reset sizes/strides on mutation
TORCH_INTERNAL_ASSERT(use_value_sizes_strides);
refreshSizesAndStrides();
}
// The following are some internal inherited methods that we do not support.
// They should never get called.
void TensorWrapper::set_size(int64_t dim, int64_t new_size) {
TORCH_INTERNAL_ASSERT(false, "Can't set_size for TensorWrapper");
}
void TensorWrapper::set_stride(int64_t dim, int64_t new_stride) {
TORCH_INTERNAL_ASSERT(false, "Can't set_stride for TensorWrapper");
}
void TensorWrapper::set_storage_offset(int64_t storage_offset) {
TORCH_INTERNAL_ASSERT(false, "Can't set_storage_offset for TensorWrapper");
}
const char* TensorWrapper::tensorimpl_type_name() const {
return "TensorWrapper";
}
TensorWrapper* maybeGetTensorWrapper(const Tensor& tensor) {
if (!tensor.key_set().has(DispatchKey::TensorWrapper)) {
return nullptr;
}
return (TensorWrapper*)(tensor.unsafeGetTensorImpl());
}
static void foreachTensorInplace(std::vector<IValue>& args, int64_t begin, int64_t end,
std::function<Tensor(const Tensor&)> func) {
TORCH_INTERNAL_ASSERT(begin >= 0);
TORCH_INTERNAL_ASSERT(end >= 0);
TORCH_INTERNAL_ASSERT(begin <= end);
for (int64_t idx = begin; idx < end; idx++) {
auto ivalue = args[idx];
if (ivalue.isTensorList()) {
TORCH_INTERNAL_ASSERT(false, "NYI: TensorList");
}
if (!ivalue.isTensor()) {
continue;
}
Tensor value = ivalue.toTensor();
Tensor replacement = func(value);
args[idx] = replacement; // TODO: std::move?
if (ivalue.toTensor().defined()) {
TORCH_INTERNAL_ASSERT(args[idx].toTensor().defined());
}
}
}
static Tensor unwrapIfDead(const Tensor& tensor) {
auto* wrapped = maybeGetTensorWrapper(tensor);
if (!wrapped) {
return tensor;
}
if (wrapped->is_alive()) {
return tensor;
}
return wrapped->value();
}
void dead_tensor_wrapper_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
auto args_size = op.schema().arguments().size();
int64_t unwrapped_count = 0;
auto unwrapIfDeadAndIncrement = [&](const Tensor& tensor) {
auto* wrapped = maybeGetTensorWrapper(tensor);
if (!wrapped) {
return tensor;
}
if (wrapped->is_alive()) {
return tensor;
}
unwrapped_count++;
return wrapped->value();
};
foreachTensorInplace(*stack, stack->size() - args_size, stack->size(), unwrapIfDeadAndIncrement);
TORCH_INTERNAL_ASSERT(unwrapped_count > 0, "Should have at least one dead wrapper");
// re-dispatch
op.callBoxed(stack);
}
// TensorWrapper backend fallback: Unwrap and fallthrough.
TORCH_LIBRARY_IMPL(_, TensorWrapper, m) {
m.fallback(torch::CppFunction::makeFromBoxedFunction<&dead_tensor_wrapper_fallback>());
}
}
} // namespace at

View File

@ -0,0 +1,61 @@
#pragma once
#include <ATen/Tensor.h>
namespace at {
namespace functorch {
struct TORCH_API TensorWrapper : public c10::TensorImpl {
explicit TensorWrapper(
c10::DispatchKeySet key_set,
Tensor value,
int64_t level,
std::shared_ptr<bool> is_alive,
bool use_value_sizes_strides = true);
// Override a bunch of methods inherited from TensorImpl to return error messages
void set_size(int64_t dim, int64_t new_size) override;
void set_stride(int64_t dim, int64_t new_stride) override;
void set_storage_offset(int64_t storage_offset) override;
void refreshSizesAndStrides();
const Tensor& value() const {
return value_;
}
optional<int64_t> level() const {
if (is_alive()) {
return level_;
}
return {};
}
bool is_alive() const;
// Overrides necessary for autograd
c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
const c10::VariableVersion& version_counter,
bool allow_tensor_metadata_change) const override;
c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
c10::VariableVersion&& version_counter,
bool allow_tensor_metadata_change) const override;
void shallow_copy_from(const c10::intrusive_ptr<TensorImpl>& impl) override;
private:
const char* tensorimpl_type_name() const override;
Tensor value_;
int64_t level_;
// When we exit the level, this wrapper may be marked as "not alive".
// Wrappers that are not alive:
// 1) May still have autograd metadata on them
// 2) Forward dispatches to the underlying value()
std::shared_ptr<bool> is_alive_;
};
TORCH_API Tensor makeTensorWrapper(const Tensor& tensor, int64_t level);
TORCH_API TensorWrapper* maybeGetTensorWrapper(const Tensor& tensor);
TORCH_API void dumpTensor(std::ostream & ss, const Tensor& tensor);
TORCH_API void dumpTensorCout(const Tensor& tensor);
}
} // namespace at

View File

@ -0,0 +1,60 @@
#include <functorch/csrc/DynamicLayer.h>
#include <functorch/csrc/VmapMode.h>
#include <functorch/csrc/Constants.h>
namespace at {
namespace functorch {
namespace impl {
/// thread_local is a feature that is not enabled by Caffe2 mobile
/// build (e.g. iOS). Therefore, we only provide `at::VmapMode`
/// when we are not in mobile build or when FEATURE_TORCH_MOBILE
/// is on.
#if !defined(C10_MOBILE) || defined(FEATURE_TORCH_MOBILE)
thread_local int64_t VmapMode_current_vmap_level = 0;
int64_t VmapMode::current_vmap_level() {
return VmapMode_current_vmap_level;
}
int64_t VmapMode::increment_nesting() {
VmapMode_current_vmap_level++;
auto level = initAndPushDynamicLayer(kBatchedKey);
if (VmapMode_current_vmap_level == 1) {
c10::impl::tls_set_dispatch_key_included(DispatchKey::VmapMode, true);
}
return level;
}
int64_t VmapMode::decrement_nesting() {
VmapMode_current_vmap_level--;
auto layer = popDynamicLayerAndDeleteMetadata();
TORCH_INTERNAL_ASSERT(layer.key() == kBatchedKey);
if (VmapMode_current_vmap_level == 0) {
c10::impl::tls_set_dispatch_key_included(DispatchKey::VmapMode, false);
}
// TODO: this return value should never be used
return VmapMode_current_vmap_level;
}
#else
int64_t VmapMode::current_vmap_level() {
TORCH_CHECK(false, "VmapMode is not supported on mobile");
}
int64_t VmapMode::increment_nesting() {
TORCH_CHECK(false, "VmapMode is not supported on mobile");
}
int64_t VmapMode::decrement_nesting() {
TORCH_CHECK(false, "VmapMode is not supported on mobile");
}
#endif
} // namespace impl
}
} // namespace at

View File

@ -0,0 +1,30 @@
#pragma once
#include <c10/core/impl/LocalDispatchKeySet.h>
namespace at {
namespace functorch {
namespace impl {
// VmapMode contains a thread local count of how many nested vmaps
// we are currently inside. That number is known as the `vmap level`.
// VmapMode is used in the implementation of the Python `torch.vmap` API.
//
// NOTE: this is NOT the c++ api for torch.vmap. That doesn't exist yet.
struct TORCH_API VmapMode {
// Returns the vmap level, aka the count of how many nested vmaps we're in.
static int64_t current_vmap_level();
// Increment the count of nested vmaps. If this causes the vmap level to be
// greater than 0, then it enables DispatchKey::VmapMode on all tensors.
static int64_t increment_nesting();
// Decrements the count of nested vmaps. If this causes the vmap level to be
// equal to 0, then it disables DispatchKey::VmapMode on all tensors.
static int64_t decrement_nesting();
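// Illustrative usage sketch (an assumption about the Python-side caller, not
// a definitive API): a vmap implementation would typically bracket each call:
//   auto level = VmapMode::increment_nesting();
//   ... run the function being vmapped ...
//   VmapMode::decrement_nesting();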
};
} // namespace impl
}
} // namespace at

View File

@ -0,0 +1,323 @@
#include <functorch/csrc/VmapTransforms.h>
#include <functorch/csrc/DynamicLayer.h>
#include <ATen/ATen.h>
namespace at {
namespace functorch {
// Checks if the batch dims in `bdims` appear at the front of the tensor.
static bool areBdimsAtFrontInOrder(BatchDimsRef bdims) {
for (int64_t idx = 0; idx < bdims.size(); idx++) {
if (bdims[idx].dim() != idx) {
return false;
}
}
return true;
}
// Takes a BatchedTensorImpl, permutes all of the batch dims to the front,
// and then returns a physical version of the Tensor.
static Tensor permuteBatchDimsToFront(BatchedTensorImpl* batched) {
auto bdims = batched->bdims();
const Tensor& physical_tensor = batched->value();
if (areBdimsAtFrontInOrder(bdims)) {
return physical_tensor;
}
const auto sizes = physical_tensor.sizes();
VmapDimVector permutation(sizes.size(), 0);
permutation.reserve(sizes.size());
const auto is_bdim = createBatchDimBitset(bdims);
int64_t idx = 0;
for (const auto& bdim : bdims) {
permutation[idx++] = bdim.dim();
}
for (int64_t ptr = 0; idx < sizes.size(); ptr++) {
if (is_bdim[ptr]) {
continue;
}
permutation[idx++] = ptr;
}
return physical_tensor.permute(permutation);
}
VmapPhysicalView MultiBatchVmapTransform::logicalToPhysical(const Tensor& logical_tensor) {
auto* batched = maybeGetBatchedImpl(logical_tensor);
TORCH_INTERNAL_ASSERT(
batched,
"logicalToPhysical(tensor) should only be passed a BatchedTensor");
return { permuteBatchDimsToFront(batched), createVmapLevelsBitset(batched->bdims()) };
}
int64_t VmapPhysicalView::numBatchDims() const {
return levels_.count();
}
int64_t VmapPhysicalView::numLogicalDims() const {
return /*physical*/tensor_.dim() - numBatchDims();
}
VmapDimVector VmapPhysicalView::getPhysicalDims(IntArrayRef logical_dims) const {
auto logical_ndim = numLogicalDims();
// NB: fmap doesn't have a SmallVector variant, so we don't use it here.
VmapDimVector result;
result.reserve(logical_ndim);
for (auto dim : logical_dims) {
result.push_back(maybe_wrap_dim(dim, logical_ndim) + numBatchDims());
}
return result;
}
int64_t VmapPhysicalView::getPhysicalDim(int64_t logical_dim) const {
auto logical_ndim = numLogicalDims();
return maybe_wrap_dim(logical_dim, logical_ndim) + numBatchDims();
}
VmapDimVector VmapPhysicalView::getPhysicalShape(IntArrayRef logical_shape) const {
VmapDimVector result;
result.reserve(logical_shape.size() + numBatchDims());
auto tensor_sizes = tensor_.sizes();
result.insert(result.end(), tensor_sizes.begin(), tensor_sizes.begin() + numBatchDims());
result.insert(result.end(), logical_shape.begin(), logical_shape.end());
return result;
}
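// Maps a set of levels to BatchDims at the front of a tensor. For example, a
// levels bitset with bits 1 and 3 set produces [(lvl=1, dim=0), (lvl=3, dim=1)].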
static BatchDims computeFrontBatchDimsFromLevels(std::bitset<kVmapNumLevels> levels_bitset) {
BatchDims bdims;
int64_t dim = 0;
for (int64_t level = 0; level < kVmapNumLevels; level++) {
if (!levels_bitset[level]) {
continue;
}
bdims.emplace_back(level, dim++);
}
return bdims;
}
// Given a Tensor or a BatchedTensor, returns the underlying physical tensor
// with all vmapped dimensions permuted to the front, if they exist, and a
// bitset of vmap levels that were present in the tensor.
static std::pair<Tensor,std::bitset<kVmapNumLevels>>
getPhysicalTensorAndLevels(const Tensor& self) {
auto* batched = maybeGetBatchedImpl(self);
if (batched) {
return {permuteBatchDimsToFront(batched), createVmapLevelsBitset(batched->bdims())};
}
return {self, 0};
}
// Given a Tensor or a BatchedTensor, creates a physical view of the tensor
// such that it has a batch dimension for each level in `requested_levels`
// and `requested_example_dim` number of non-batch-dimensions.
//
// This function is useful in preparing physical views on tensors that can
// then be passed into broadcasting operations. For example, when adding
// two BatchedTensors of sizes [B0, 3] and [B0, B1, 2, 3], where the Bi are the
// batch dimensions, we must align the batch dimensions and non-batch-dimensions
// (henceforth referred to as the "example" dimensions) separately to produce
// tensors of size [B0, 1, 1, 3] and [B0, B1, 2, 3] so that they can be added.
//
// Here's a direct example of using alignBatchDimsAtFront on the above two tensors.
//
// 1) alignBatchDimsAtFront([B0, 3], requested_levels={0, 1}, requested_example_dim=2)
// returns a physical view of size [B0, 1, 1, 3] by adding an extra dimension for
// level 1 and another extra dimension to pad the example dimensions to 2.
//
// 2) alignBatchDimsAtFront([B0, B1, 2, 3], requested_levels={0, 1}, requested_example_dim=2)
// returns a physical view of size [B0, B1, 2, 3]
static Tensor alignBatchDimsAtFront(
const Tensor& self,
std::bitset<kVmapNumLevels> requested_levels,
int64_t requested_example_dim) {
Tensor physical_tensor;
std::bitset<kVmapNumLevels> tensor_levels;
std::tie(physical_tensor, tensor_levels) = getPhysicalTensorAndLevels(self);
TORCH_INTERNAL_ASSERT(
(tensor_levels | requested_levels) == requested_levels,
"`requested_levels` must be a superset of `self`'s levels");
auto physical_sizes = physical_tensor.sizes();
auto tensor_example_dim = physical_sizes.size() - /*num_batch_dims*/tensor_levels.count();
TORCH_INTERNAL_ASSERT(tensor_example_dim <= requested_example_dim);
if (tensor_levels == requested_levels && tensor_example_dim == requested_example_dim) {
// Optimization: no need to do another view if the physical tensor is
// already the correct shape
return physical_tensor;
}
VmapDimVector aligned_sizes(requested_levels.count() + requested_example_dim, 1);
// align the example dims (non-bdims dims) first
// aligned_sizes[-tensor_example_dim:] = tensor_sizes[-tensor_example_dim:]
std::copy(
physical_sizes.rbegin(),
physical_sizes.rbegin() + tensor_example_dim,
aligned_sizes.rbegin());
// align the bdims
int64_t level = 0;
int64_t tensor_dim = 0;
for (int64_t bdim = 0; bdim < requested_levels.count(); bdim++) {
// Determine the level of the bdim
while (!requested_levels[level]) level++;
if (tensor_levels[level]) {
aligned_sizes[bdim] = physical_sizes[tensor_dim++];
}
level++;
}
return physical_tensor.view(aligned_sizes);
}
static Tensor moveDimToFrontAndExpand(Tensor tensor, optional<int64_t> dim, int64_t size) {
if (dim) {
tensor = tensor.movedim(*dim, 0);
} else {
tensor = tensor.unsqueeze(0);
auto expanded_sizes = tensor.sizes().vec();
expanded_sizes[0] = size;
tensor = tensor.expand(expanded_sizes);
}
return tensor;
}
// The algorithm is as follows:
// 1. Figure out the collective set of levels across all of `logical_tensors`.
// 2. Move all batch dims to the front of the tensors and add extra dims
// of size 1. At this point, every tensor will have a dimension for
// each of the collective levels.
// 3. Compute the batch_sizes.
// 4. Expand each physical tensor so that it has batch sizes equal
//    to `batch_sizes`.
VmapPhysicalViewVec
MultiBatchVmapTransform::logicalToPhysical(TensorList logical_tensors) {
auto cur_level = maybeCurrentDynamicLayer().value().layerId();
auto bdim_size = -1;
// Figure out the batch size first
for (const auto& logical_tensor : logical_tensors) {
auto* batched = maybeGetBatchedImpl(logical_tensor);
if (!batched) {
continue;
}
TORCH_INTERNAL_ASSERT(batched->bdims().size() == 1);
if (batched->bdims().back().level() != cur_level) {
continue;
}
bdim_size = batched->value().size(batched->bdims().back().dim());
}
TORCH_INTERNAL_ASSERT(bdim_size != -1);
std::bitset<kVmapNumLevels> levels;
levels[cur_level] = 1;
VmapPhysicalViewVec result;
for (const auto& logical_tensor : logical_tensors) {
auto* batched = maybeGetBatchedImpl(logical_tensor);
if (!batched || (batched->bdims().back().level() != cur_level)) {
// Unsqueeze dim 0, expand it to the correct shape
c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
auto value = moveDimToFrontAndExpand(logical_tensor, {}, bdim_size);
result.emplace_back(std::move(value), levels);
continue;
}
c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
auto physical = batched->value();
auto value = moveDimToFrontAndExpand(physical, batched->bdims().back().dim(), bdim_size);
result.emplace_back(std::move(value), levels);
}
return result;
}
static std::pair<std::bitset<kVmapNumLevels>,int64_t>
getLevelsAndLargestLogicalDim(TensorList logical_tensors) {
TORCH_INTERNAL_ASSERT(logical_tensors.size() > 0);
std::bitset<kVmapNumLevels> levels;
int64_t largest_logical_dim = -1;
for (const auto& tensor : logical_tensors) {
auto* batched = maybeGetBatchedImpl(tensor);
if (batched) {
levels = levels | createVmapLevelsBitset(batched->bdims());
}
auto tensor_logical_dim = /*logical dim*/tensor.dim();
if (tensor_logical_dim > largest_logical_dim) {
largest_logical_dim = tensor_logical_dim;
}
}
return { levels, largest_logical_dim };
}
static Tensor moveDimToFrontAndUnsqueeze(Tensor tensor, optional<int64_t> dim, int64_t example_ndim) {
if (dim) {
tensor = tensor.movedim(*dim, 0);
} else {
tensor = tensor.unsqueeze(0);
}
auto ndim = tensor.dim() - 1;
for (int64_t i = 0; i < example_ndim - ndim; i++) {
tensor = tensor.unsqueeze(1);
}
return tensor;
}
VmapPhysicalViewVec BroadcastingVmapTransform::logicalToPhysical(TensorList logical_tensors) {
auto cur_level = maybeCurrentDynamicLayer().value().layerId();
auto bdim_size = -1;
// Figure out the batch size first
for (const auto& logical_tensor : logical_tensors) {
auto* batched = maybeGetBatchedImpl(logical_tensor);
if (!batched || (batched->bdims().back().level() != cur_level)) {
continue;
}
bdim_size = batched->value().size(batched->bdims().back().dim());
}
TORCH_INTERNAL_ASSERT(bdim_size != -1);
std::bitset<kVmapNumLevels> levels;
levels[cur_level] = 1;
// figure out the example ndim
int64_t max_example_dim = -1;
for (const auto& logical_tensor : logical_tensors) {
max_example_dim = std::max(logical_tensor.dim(), max_example_dim);
}
VmapPhysicalViewVec result;
for (const auto& logical_tensor : logical_tensors) {
auto* batched = maybeGetBatchedImpl(logical_tensor);
if (!batched || (batched->bdims().back().level() != cur_level)) {
// Unsqueeze dim 0, expand it to the correct shape
c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
auto value = moveDimToFrontAndUnsqueeze(logical_tensor, {}, max_example_dim);
result.emplace_back(std::move(value), levels);
continue;
}
c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
auto physical = batched->value();
auto value = moveDimToFrontAndUnsqueeze(physical, batched->bdims().back().dim(), max_example_dim);
result.emplace_back(std::move(value), levels);
}
return result;
}
VmapPhysicalToLogicalMap VmapPhysicalView::getPhysicalToLogicalMap() const {
return VmapPhysicalToLogicalMap(levels_);
}
Tensor VmapPhysicalToLogicalMap::apply(const Tensor& physical_tensor) const {
return makeBatched(physical_tensor, computeFrontBatchDimsFromLevels(levels_));
}
void VmapPhysicalToLogicalMap::applyInplace(std::vector<Tensor>& physical_tensors) const {
for (int64_t idx = 0; idx < physical_tensors.size(); ++idx) {
physical_tensors[idx] = apply(physical_tensors[idx]);
}
}
}
} // namespace at

View File

@ -0,0 +1,177 @@
#pragma once
#include <functorch/csrc/BatchedTensorImpl.h>
namespace at {
namespace functorch {
// This file contains abstractions used for transforming *logical* vmap arguments
// into *physical* arguments. (Keep reading for definitions of these terms).
// NOTE: [Logical vs physical args]
// Consider the following vmap.
// vmap(vmap(func, in_dims=(2,)), in_dims=(0,))(torch.ones(2, 3, 4))
// This would produce a BatchedTensor wrapping a Tensor of size [2, 3, 4],
// with batch dims 0 and 2:
// BatchedTensor(ones(2, 3, 4), bdims=[(lvl=1,dim=0),(lvl=2,dim=2)])
//
// We say the *logical* view of the tensor has size [3] -- tensors inside
// `func` appear to have size [3].
// However, the *physical* underlying tensor (the one passed to vmap) has size
// [2, 3, 4].
//
// This notion of logical vs physical also extends to non-tensor arguments.
// Consider the previous tensor; let's assume the user called
// `torch.sum(tensor, dim=0)` inside of `func`. Then the logical
// dimension they are reducing over is dim 0 but the physical dim is dim 1
// (the first non-batch dimension).
// Forward declared; see NOTE: [What is a VmapPhysicalView?]
struct VmapPhysicalView;
// Most PyTorch operators take 4 or fewer inputs.
constexpr int64_t kVmapTransformStaticInputSize = 4;
using VmapPhysicalViewVec = SmallVector<VmapPhysicalView, kVmapTransformStaticInputSize>;
// PyTorch generally advertises good performance for <= 5 dims.
// (see ATen/core/DimVector.h). We add a few extra dims (~3) for vmap
// dimensions to get 8. Adjust this number as necessary.
constexpr int64_t kVmapStaticDimVecSize = 8;
using VmapDimVector = SmallVector<int64_t, kVmapStaticDimVecSize>;
// NOTE: [What is a VmapTransform?]
// A *VmapTransform* converts logical views of tensors to physical views.
//
// Batching rules use VmapTransforms to convert logical arguments to
// physical arguments, then call one or more at:: operators that handle the
// physical arguments, and then convert the physical results back to logical
// arguments.
// VmapTransform for operators that take tensors with multiple batch dims.
// Given one or more logical views on Tensors, `logicalToPhysical`
// permutes all of the batch dims to the front of the tensor, aligns
// and expands the batch dims to match each other (according to their `level`),
// and returns a VmapPhysicalView on the tensor(s).
struct TORCH_API MultiBatchVmapTransform {
static VmapPhysicalView logicalToPhysical(const Tensor& logical_tensor);
static VmapPhysicalViewVec logicalToPhysical(TensorList logical_tensors);
};
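// For instance, a hypothetical sum-like batching rule might use the transform
// roughly as follows (a sketch for exposition, not an actual batching rule):
//   auto physical_view = MultiBatchVmapTransform::logicalToPhysical(logical_tensor);
//   auto physical_dim = physical_view.getPhysicalDim(/*logical_dim=*/0);
//   auto physical_result = physical_view.tensor().sum(physical_dim);
//   auto logical_result = physical_view.getPhysicalToLogicalMap().apply(physical_result);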
// VmapTransform for operators that broadcast all inputs.
// Given some logical views on Tensors, `logicalToPhysical`:
// - permutes all of the batch dims to the front of the tensors
// - aligns all the batch dims to the collective levels of all of the tensors.
// If a tensor does not have a batch dim for a vmap level, then it receives
// a size-one dimension for said level.
// - aligns the non-batch dims to have the same dimensionality, adding extra
// size-1 dimensions in between the batch dimensions and the non-batch dimensions
// so that the non-batch dimensions line up from the right.
//
// For example: given inputs of size (B, 2) and (B, 3, 2) where B is the batch
// dimension, BroadcastingVmapTransform returns VmapPhysicalViews that wrap tensors
// of size (B, 1, 2) and (B, 3, 2).
//
// Given inputs of size (B, 2) and (2,), BroadcastingVmapTransform returns
// VmapPhysicalViews wrapping tensors of size (B, 2) and (1, 2). We don't
// actually *need* to return a tensor of size (1, 2) for the second tensor
// because the broadcasting operation takes care of that for us, but we do
// it anyway to keep things simple.
struct TORCH_API BroadcastingVmapTransform {
static VmapPhysicalViewVec logicalToPhysical(TensorList logical_tensors);
};
// Forward declared; if you're reading this file head to toe, don't worry about
// it yet.
struct VmapPhysicalToLogicalMap;
// NOTE: [What is a VmapPhysicalView?]
// VmapPhysicalView represents a physical view on a Tensor.
//
// One can use it to further convert logical dimension indices, logical shapes,
// and more to their physical variants, or convert a new (physical) tensor into
// a logical BatchedTensor. (TODO(rzou): some of these are not yet implemented).
//
// VmapPhysicalView stores a physical tensor with all of its batch dimensions at
// the front and some levels that correspond to said batch dimensions.
//
// The levels bitset specifies which vmap levels correspond to the batch
// dimensions at the front of the tensor. In particular, the number of set bits
// corresponds to the number of batch dimensions on `tensor`, and the position of
// the rightmost set bit of `levels` is an upper bound on the number of nested
// vmaps we are in at this point in time.
// For example, given:
// physical_view = VmapPhysicalView(tensor=ones(2, 3, 4, 5, 6), levels={1, 3})
//
// The rightmost set bit of `levels` is at position 3, indicating that the
// number of nested vmaps is less than or equal to 3.
//   bitset: 010100
//              ^
//              |
//   levels: 012345
struct TORCH_API VmapPhysicalView {
VmapPhysicalView(Tensor&& tensor, std::bitset<kVmapNumLevels> levels)
: levels_(levels), tensor_(std::move(tensor)) {
// TORCH_INTERNAL_ASSERT(!isBatchedTensor(tensor));
}
Tensor& tensor() { return tensor_; }
const Tensor& tensor() const { return tensor_; }
// Maps logical dim indices to physical dim indices. Also does dim wrapping.
//
// For example, given:
// physical_view = VmapPhysicalView(tensor=ones(2, 3, 4, 5), levels={1, 3})
//
// Then physical_view.getPhysicalDims({0, 1}) returns {2, 3}.
// This is because the number of set bits in `levels` tells us that the first two dimensions
// of `tensor_` are batch dimensions, so a logical dim of `n` is actually
// a physical dim of `n + 2`.
VmapDimVector getPhysicalDims(IntArrayRef logical_dims) const;
int64_t getPhysicalDim(int64_t logical_dim) const;
// Returns a VmapPhysicalToLogicalMap object. This can be used for
// mapping a physical tensor to a new logical tensor (BatchedTensor)
VmapPhysicalToLogicalMap getPhysicalToLogicalMap() const;
// Maps a logical shape to a physical shape by prepending the batch
// sizes to the logical shape.
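// For example (illustrative), with the physical_view example above (batch sizes
// 2 and 3), getPhysicalShape({7, 8}) returns {2, 3, 7, 8}.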
VmapDimVector getPhysicalShape(IntArrayRef logical_shape) const;
int64_t numBatchDims() const;
private:
int64_t numLogicalDims() const;
std::bitset<kVmapNumLevels> levels_;
Tensor tensor_;
};
// Convenience struct used for mapping a physical tensor (a non-BatchedTensor)
// to a logical one (BatchedTensor). It holds some levels that are used to do the
// mapping and assumes that the batch dimensions in the physical tensor all
// occur at the front of the tensor.
struct TORCH_API VmapPhysicalToLogicalMap {
VmapPhysicalToLogicalMap(std::bitset<kVmapNumLevels> levels): levels_(levels) {}
// Maps a physical tensor to a new logical tensor (BatchedTensor).
// Assumes that all of the "batch dimensions" are at the front
// of the physical tensor. For example, given:
// - x = rank-4 Tensor with size 2, 3, 5, 7
// - levels = (2, 4)
// Returns:
// - BatchedTensor(x, bdims=[(dim=0,lvl=2), (dim=1, lvl=4)])
Tensor apply(const Tensor& physical_tensor) const;
// Given a vector of physical tensors,
// 1. maps each tensor to a new logical tensor. Assumes that all of the
// "batch dimensions" are at the front of the physical tensors.
// 2. stores the new logical tensors back into the passed-in vector. This is
// to avoid additional dynamic allocations.
void applyInplace(std::vector<Tensor>& physical_tensors) const;
std::bitset<kVmapNumLevels> levels_;
};
} // namespace functorch
} // namespace at

View File

@ -0,0 +1,176 @@
#include <torch/extension.h>
#include <ATen/WrapDimUtils.h>
#include <functorch/csrc/TensorWrapper.h>
#include <functorch/csrc/DynamicLayer.h>
#include <functorch/csrc/BatchedTensorImpl.h>
#include <functorch/csrc/VmapTransforms.h>
#include <functorch/csrc/VmapMode.h>
namespace at {
namespace functorch {
static bool has_level(const Tensor& self, int64_t level) {
const auto* batched = maybeGetBatchedImpl(self);
if (!batched) {
return false;
}
auto bdims = batched->bdims();
auto* it = std::find_if(bdims.begin(), bdims.end(), [&](const BatchDim& bdim) {
return bdim.level() == level;
});
return it != bdims.end();
}
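// Wraps `self` in a BatchedTensor with a batch dim at `batch_dim` for vmap
// level `level`. For example (illustrative): _add_batch_dim(x, /*batch_dim=*/0,
// /*level=*/1) on an `x` of size (2, 3) returns a BatchedTensor with logical
// size [3] and bdims=[(lvl=1, dim=0)].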
Tensor _add_batch_dim(const Tensor& self, int64_t batch_dim, int64_t level) {
return addBatchDim(self, level, batch_dim);
}
static std::pair<Tensor,int64_t> remove_existing_batch_dim(
const BatchedTensorImpl* batched, int64_t level) {
auto bdims = batched->bdims();
if (bdims.size() == 1) {
TORCH_INTERNAL_ASSERT(bdims[0].level() == level);
return std::make_pair(batched->value(), bdims[0].dim());
}
BatchDims new_bdims;
int64_t newly_exposed_physical_dim = -1;
new_bdims.reserve(bdims.size() - 1);
for (const auto& bdim : bdims) {
if (bdim.level() == level) {
newly_exposed_physical_dim = bdim.dim();
} else {
new_bdims.push_back(bdim);
}
}
// Because a BatchDim with level `level` must exist inside `batched`,
// we should have found a `newly_exposed_physical_dim`.
TORCH_INTERNAL_ASSERT(newly_exposed_physical_dim != -1);
int64_t num_batch_dims_before_newly_exposed_physical_dim = std::count_if(
new_bdims.begin(), new_bdims.end(),
[&](const BatchDim& bdim) {
return bdim.dim() < newly_exposed_physical_dim;
});
int64_t newly_exposed_logical_dim =
newly_exposed_physical_dim - num_batch_dims_before_newly_exposed_physical_dim;
auto result_tensor = makeBatched(batched->value(), std::move(new_bdims));
return std::make_pair(std::move(result_tensor), newly_exposed_logical_dim);
}
// Poor man's version of np.moveaxis. Moves the dimension at `src` to `dst`
// while preserving the order of other existing dimensions.
// We should probably add np.moveaxis (it is more general) to PyTorch. (#36048)
// When we do, replace the following with it.
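// For example (illustrative): for `self` of size (2, 3, 4),
// _movedim(self, /*src=*/0, /*dst=*/2) returns a view of size (3, 4, 2).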
static Tensor _movedim(const Tensor& self, int64_t src, int64_t dst) {
auto logical_dim = self.dim();
src = maybe_wrap_dim(src, logical_dim);
dst = maybe_wrap_dim(dst, logical_dim);
if (src == dst) {
return self;
}
VmapDimVector permutation;
permutation.reserve(logical_dim);
for (int64_t dim = 0; dim < logical_dim; dim++) {
if (dim == src) {
continue;
}
permutation.push_back(dim);
}
permutation.insert(permutation.begin() + dst, src);
return self.permute(permutation);
}
// Removes the batch dim with level `level` from `self`. If this causes the
// last batch dim to be removed from a BatchedTensor, then this returns a
// regular Tensor.
//
// If the `level` of the batch dim to remove does not exist in `self`, then we
// add the batch dim in. This can happen if `self` didn't interact with a tensor
// inside the vmap level, for example,
// self = torch.randn(3)
// y = torch.randn(5)
// out = vmap(lambda x: vmap(lambda y: x)(y))(self)
// assert out.shape == (3, 5)
// Inside the inner vmap, `x` is a BatchedTensor with a single batch dimension
// corresponding to the *outer* vmap level and it doesn't have any dimensions that
// correspond to the inner vmap level so we need to create one for the user.
//
// `out_dim` controls where we should put the batch dimension in the output tensor.
Tensor _remove_batch_dim(const Tensor& self, int64_t level, int64_t batch_size, int64_t out_dim) {
if (!has_level(self, level)) {
auto self_sizes = self.sizes();
VmapDimVector expanded_sizes(self_sizes.begin(), self_sizes.end());
expanded_sizes.insert(expanded_sizes.begin() + out_dim, batch_size);
return self.expand(expanded_sizes);
}
// `self` must be a BatchedTensor because has_level(self, level) returned true.
const auto* batched = maybeGetBatchedImpl(self);
TORCH_INTERNAL_ASSERT(batched != nullptr);
Tensor self_without_bdim;
int64_t newly_exposed_logical_dim;
std::tie(self_without_bdim, newly_exposed_logical_dim) = remove_existing_batch_dim(batched, level);
return _movedim(self_without_bdim, newly_exposed_logical_dim, out_dim);
}
Tensor _wrap_for_grad(const Tensor& self, int64_t level) {
// NB: different behavior inside??
// return self;
// TORCH_INTERNAL_ASSERT(!maybeGetTensorWrapper(self));
// TORCH_INTERNAL_ASSERT(self.has_storage());
return makeTensorWrapper(self, level);
}
Tensor _unwrap_for_grad(const Tensor& self, int64_t level) {
auto* result = maybeGetTensorWrapper(self);
if (!result) {
return self;
}
TORCH_INTERNAL_ASSERT(result->level().has_value());
if (result->level() == level) {
return result->value();
}
return self;
}
int64_t dlevel(const Tensor& tensor) {
auto* wrapped = maybeGetTensorWrapper(tensor);
if (!wrapped) {
return 0;
}
if (!wrapped->is_alive()) {
return -1;
}
return wrapped->level().value();
}
bool dump_tensor(const Tensor& self) {
dumpTensorCout(self);
return true;
}
int64_t _grad_increment_nesting() {
return initAndPushDynamicLayer(at::DispatchKey::Autograd);
}
int64_t _grad_decrement_nesting() {
return popDynamicLayerAndDeleteMetadata().layerId();
}
} // namespace functorch
} // namespace at
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("_add_batch_dim", &at::functorch::_add_batch_dim, "add batch dim");
m.def("_remove_batch_dim", &at::functorch::_remove_batch_dim, "remove batch dim");
m.def("_vmapmode_increment_nesting", &at::functorch::impl::VmapMode::increment_nesting, "add batch dim");
m.def("_vmapmode_decrement_nesting", &at::functorch::impl::VmapMode::decrement_nesting, "remove batch dim");
m.def("_grad_increment_nesting", &at::functorch::_grad_increment_nesting, "remove batch dim");
m.def("_grad_decrement_nesting", &at::functorch::_grad_decrement_nesting, "remove batch dim");
m.def("_wrap_for_grad", &at::functorch::_wrap_for_grad, "add batch dim");
m.def("_unwrap_for_grad", &at::functorch::_unwrap_for_grad, "add batch dim");
m.def("dlevel", &at::functorch::dlevel, "add batch dim");
m.def("dump_tensor", &at::functorch::dump_tensor, "add batch dim");
}
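To make the pybind11 surface above concrete, here is a minimal sketch of how a single-input vmap could be assembled from these bindings. It is an illustration only, not the actual functorch frontend; `sketch_vmap` is hypothetical, and the assumption that `_vmapmode_increment_nesting()` returns the new level id is flagged in the comments.

import torch
import functorch._C as _C  # assumed extension module name ("functorch._C"), as built elsewhere in this commit

def sketch_vmap(func, x, in_dim=0, out_dim=0):
    # Enter a new vmap level. Assumption: the binding returns the new level id.
    level = _C._vmapmode_increment_nesting()
    try:
        batch_size = x.size(in_dim)
        # Hide the batch dim from `func` by wrapping x as a BatchedTensor.
        bx = _C._add_batch_dim(x, in_dim, level)
        by = func(bx)  # func only sees the logical (per-example) view
        # Re-expose the batch dim at `out_dim`, materializing it if `func` ignored x.
        return _C._remove_batch_dim(by, level, batch_size, out_dim)
    finally:
        _C._vmapmode_decrement_nesting()

For an elementwise function such as torch.sin, sketch_vmap(torch.sin, torch.randn(5)) is expected to behave like the vmap calls exercised in the tests below.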

71
functorch/setup.py Normal file
View File

@ -0,0 +1,71 @@
import distutils
import shutil
import glob
import os
from setuptools import setup, find_packages
from torch.utils.cpp_extension import (
CppExtension,
BuildExtension,
)
# class clean(distutils.command.clean.clean):
# def run(self):
# with open(".gitignore", "r") as f:
# ignores = f.read()
# for wildcard in filter(None, ignores.split("\n")):
# for filename in glob.glob(wildcard):
# try:
# os.remove(filename)
# except OSError:
# shutil.rmtree(filename, ignore_errors=True)
#
# # It's an old-style class in Python 2.7...
# distutils.command.clean.clean.run(self)
def get_extensions():
extension = CppExtension
define_macros = []
extra_link_args = []
extra_compile_args = {"cxx": ["-O3", "-g", "-std=c++14"]}
if int(os.environ.get("DEBUG", 0)):
extra_compile_args = {
"cxx": ["-O0", "-fno-inline", "-g", "-std=c++14"]}
extra_link_args = ["-O0", "-g"]
this_dir = os.path.dirname(os.path.abspath(__file__))
extensions_dir = os.path.join(this_dir, "functorch", "csrc")
extension_sources = set(
os.path.join(extensions_dir, p)
for p in glob.glob(os.path.join(extensions_dir, "*.cpp"))
)
sources = list(extension_sources)
include_dirs = [extensions_dir]
ext_modules = [
extension(
"functorch._C",
sources,
include_dirs=[this_dir],
define_macros=define_macros,
extra_compile_args=extra_compile_args,
extra_link_args=extra_link_args,
)
]
return ext_modules
setup(
name='functorch',
url="https://github.com/zou3519/functorch",
packages=find_packages(),
ext_modules=get_extensions(),
cmdclass={
# "clean": clean,
"build_ext": BuildExtension
})

View File

@ -0,0 +1,438 @@
from torch.testing._internal.common_utils import TestCase, run_tests
import torch
import torch.nn as nn
import torch.nn.functional as F
import unittest
import functools
import itertools
import warnings
from torch.testing._internal.common_device_type import instantiate_device_type_tests, \
skipCUDAIfNoMagma
import types
from functools import partial
import functorch
from functorch import grad, vjp, vmap, make_functional, jacrev
class TestGradTransform(TestCase):
def test_primitive(self):
x = torch.randn([])
result = grad(torch.sin)(x)
self.assertEqual(result, torch.cos(x))
def test_composite_simple(self):
x = torch.randn(2, 3, 4)
result = grad(lambda x: torch.flatten(x).sum())(x)
self.assertEqual(result, torch.ones_like(x))
def test_composite_complicated(self):
x = torch.randn(3)
y = torch.randn(3, 5)
def foo(x, y):
result = x @ y
return result.sum()
result = grad(foo)(x, y)
x.requires_grad_()
out = foo(x, y)
expected, = torch.autograd.grad(out, x)
self.assertEqual(result, expected)
def test_composite_two_ops(self):
N, C = 2, 5
y = torch.randn(N, C)
targets = torch.randint(0, C, (N,))
def foo(y, targets):
return F.cross_entropy(y, targets)
result = grad(foo)(y, targets)
y.requires_grad_()
expected, = torch.autograd.grad(foo(y, targets), y)
self.assertEqual(result, expected)
def _test_attributes(self, get_attr_lambda):
x = torch.randn(2, 3, 5, dtype=torch.double)
expected = get_attr_lambda(x)
def foo(x):
self.assertEqual(get_attr_lambda(x), expected)
return x.sum()
grad(foo)(x)
def test_shape(self):
self._test_attributes(lambda x: x.shape)
def test_dtype(self):
self._test_attributes(lambda x: x.dtype)
def test_is_cuda(self):
self._test_attributes(lambda x: x.is_cuda)
def test_numel(self):
self._test_attributes(lambda x: x.numel())
def test_inplace(self):
x = torch.randn([])
def foo(x):
return x.clone().sin_()
result = grad(foo)(x)
self.assertEqual(result, x.cos())
def test_inplace_on_view(self):
x = torch.randn(3)
def foo(x):
y = x.clone()
y0 = y[0]
y0.sin_()
return y.sum()
result = grad(foo)(x)
x.requires_grad_()
out = foo(x)
expected, = torch.autograd.grad(out, x)
self.assertEqual(result, expected)
def test_inplace_on_view_base(self):
x = torch.randn(3)
def foo(x):
y = x.clone()
y0 = y[0]
y.sin_()
return y0
result = grad(foo)(x)
x.requires_grad_()
out = foo(x)
expected, = torch.autograd.grad(out, x)
self.assertEqual(result, expected)
def test_nesting_simple(self):
x = torch.randn([])
result = grad(grad(torch.sin))(x)
self.assertEqual(result, -torch.sin(x))
def test_escaped_wrappers_are_marked_as_dead(self):
x = torch.randn([])
escaped = []
def foo(x):
y = x.sin()
escaped.append(y)
return y
result = grad(foo)(x)
self.assertEqual(escaped[0].dlevel(), -1)
def test_escaped_wrappers_are_ignored(self):
x = torch.randn([])
escaped = []
def foo(x):
y = x.sin()
escaped.append(y)
return y
result = grad(foo)(x)
something = escaped[0].sum()
self.assertEqual(something.dlevel(), 0)
self.assertEqual(something, x.sin().sum())
def test_vjp(self):
x = torch.randn([])
out, vjp_fn = vjp(torch.sin, x)
self.assertEqual(out, x.sin())
v = torch.randn([])
result, = vjp_fn(v)
self.assertEqual(result, v * x.cos())
def test_composed_with_autograd(self):
x = torch.randn([], requires_grad=True)
y = grad(torch.sin)(x)
result, = torch.autograd.grad(y, x)
self.assertEqual(result, -x.sin())
def test_grad_of_vjp_composition(self):
x = torch.randn([])
y = torch.randn([])
def foo(x, y):
out, vjp_fn = vjp(torch.sin, x)
return grad(lambda y: vjp_fn(y)[0])(y)
result = foo(x, y)
expected = x.cos()
self.assertEqual(result, expected)
def test_vjp_of_grad_composition(self):
x = torch.randn([])
y = torch.randn([])
def foo(x, y):
out, vjp_fn = vjp(grad(torch.sin), x)
return vjp_fn(y)[0]
result = foo(x, y)
expected = -y * x.sin()
self.assertEqual(result, expected)
def test_grad_of_vjp_of_grad_composition(self):
x = torch.randn([])
y = torch.randn([])
def foo(x, y):
df, vjp_fn = vjp(grad(lambda x: -torch.cos(x)), x)
return grad(lambda y: vjp_fn(y)[0])(y)
result = foo(x, y)
expected = x.cos()
self.assertEqual(result, expected)
def test_views(self):
x = torch.randn([], requires_grad=True)
y = torch.randn([], requires_grad=True)
def silly_sin(x):
x = x.view([])
x = x.sin()
return x
def foo(x, y):
z1 = grad(silly_sin)(x)
z2 = torch.cos(y)
return z1 + z2
result = foo(x, y)
grads = torch.autograd.grad(result, [x, y])
self.assertEqual(grads[0], -x.sin())
self.assertEqual(grads[1], -y.sin())
def test_view_inplace_simple(self):
def foo(x):
x = x.clone()
x.view([]).sin_()
return x
x = torch.randn([], requires_grad=True)
result = grad(foo)(x)
self.assertEqual(result, x.cos())
class TestVmapOfGrad(TestCase):
def test_per_sample_grads_inplace_view(self):
def compute_loss(weight, x, t):
x = x.mm(weight)
y = x.squeeze_(0)
return (y - t).sum()
weight = torch.randn(16, 2)
x = torch.randn(64, 1, 16)
t = torch.randn(64, 2)
result = vmap(partial(grad(compute_loss), weight))(x, t)
expected = [grad(compute_loss)(weight, x[i], t[i]) for i in range(64)]
expected = torch.stack(expected)
# TODO: Check if the rtol is a problem
self.assertEqual(result, expected, atol=0, rtol=5e-4)
def test_new_zeros_materializes_tensor(self):
N = 3
C = 5
def foo(x, y):
result = x.new_zeros((C,))
result.copy_(y)
return result.sum()
x = torch.randn(N)
y = torch.randn(N, C)
result = vmap(grad(foo))(x, y)
def test_per_sample_grads_simple(self):
def compute_loss(weight, x, t):
y = x @ weight
return ((y - t) ** 2).sum()
weight = torch.randn(16, 2)
x = torch.randn(64, 16)
t = torch.randn(64, 2)
result = vmap(partial(grad(compute_loss), weight))(x, t)
expected = [grad(compute_loss)(weight, x[i], t[i]) for i in range(64)]
expected = torch.stack(expected)
# TODO: Check if the rtol is a problem
self.assertEqual(result, expected, atol=0, rtol=5e-4)
def test_per_sample_grads_embeddingnet(self):
class SampleNet(nn.Module):
def __init__(self, vocab_size: int):
super().__init__()
self.emb = nn.Embedding(vocab_size, 16)
self.fc1 = nn.Linear(16, 16)
self.fc2 = nn.Linear(16, 2)
def forward(self, x):
x = self.emb(x)
x = torch.transpose(x, -1, -2)
x = torch.mean(x, -1)
x = self.fc1(x)
x = F.relu(x)
x = self.fc2(x)
return x
def name(self):
return "SampleNet"
# Create our inputs...
vocab_size = 1000
batch_shape = [64]
words_per_sentence = 5
data = torch.randint(0, vocab_size, (*batch_shape, words_per_sentence))
targets = torch.randint(0, 1, (*batch_shape,))
# Construct our module
net = SampleNet(vocab_size)
criterion = nn.CrossEntropyLoss()
params = dict(net.named_parameters())
weights, net_func, _ = make_functional(net)
def compute_loss(weights, data, target):
output = net_func(weights, (data,))
result = criterion(output, target)
return result
expected = [grad(compute_loss)(weights, data[i], targets[i]) for i in range(64)]
expected = zip(*expected)
expected = tuple(torch.stack(shards) for shards in expected)
result = vmap(partial(grad(compute_loss), weights))(data, targets)
for r, e in zip(result, expected):
# TODO: Check if the rtol is a problem
self.assertEqual(r, e, atol=0, rtol=1e-4)
class TestJacrev(TestCase):
def test_simple(self):
x = torch.randn(3)
y = jacrev(torch.sin)(x)
expected = torch.diagflat(x.cos())
assert torch.allclose(y, expected)
def test_simple_not_flat(self):
x = torch.randn(2, 3)
y = jacrev(torch.sin)(x)
expected = torch.diagflat(x.view(-1).cos())
expected = expected.view(2, 3, 2, 3)
assert torch.allclose(y, expected)
def test_vmap_on_jacrev_simple(self):
x = torch.randn(2, 3)
y = vmap(jacrev(torch.sin))(x)
expected = torch.stack([torch.diagflat(x[i].cos()) for i in range(2)])
assert torch.allclose(y, expected)
def test_hessian_simple(self):
def foo(x):
return x.sin().sum()
x = torch.randn(3)
y = jacrev(jacrev(foo))(x)
expected = torch.diagflat(-x.sin())
assert torch.allclose(y, expected)
class TestComposability(TestCase):
def test_grad_grad(self):
x = torch.randn([])
y = grad(grad(torch.sin))(x)
self.assertEqual(y, -x.sin())
def test_grad_vmap(self):
def foo(x):
y = vmap(torch.sin)(x)
return y.sum()
x = torch.randn(3)
y = grad(foo)(x)
self.assertEqual(y, x.cos())
def test_grad_vjp(self):
x = torch.randn(3)
def foo(x):
_, vjp_fn = vjp(torch.sin, x)
return vjp_fn(x)[0].sum()
y = grad(foo)(x)
expected = grad(lambda x: (x * x.cos()).sum())(x)
self.assertEqual(y, expected)
def test_vmap_grad(self):
x = torch.randn(3)
y = vmap(grad(torch.sin))(x)
self.assertEqual(y, x.cos())
def test_vmap_vmap(self):
x = torch.randn(2, 3)
y = vmap(vmap(torch.sin))(x)
self.assertEqual(y, x.sin())
def test_vmap_vjp(self):
x = torch.randn(3)
_, vjp_fn = vjp(torch.sin, x)
def foo(x):
_, vjp_fn = vjp(torch.sin, x)
return vjp_fn(x)
y = vmap(foo)(x)
self.assertEqual(y, vjp_fn(x))
xs = torch.randn(5, 3)
expected = torch.stack([vjp_fn(x)[0] for x in xs])
self.assertEqual(vmap(lambda x: vjp_fn(x)[0])(xs), expected)
def test_vjp_grad(self):
x = torch.randn([])
y, vjp_fn = vjp(grad(torch.sin), x)
self.assertEqual(y, x.cos())
v = torch.randn([])
self.assertEqual(vjp_fn(v)[0], -x.sin() * v)
def test_vjp_vmap(self):
x = torch.randn(3)
y, vjp_fn = vjp(vmap(torch.sin), x)
self.assertEqual(y, x.sin())
v = torch.randn(3)
self.assertEqual(vjp_fn(v)[0], x.cos() * v)
def test_vjp_vjp(self):
x = torch.randn(3)
y, vjp_fn = vjp(torch.sin, x)
self.assertEqual(y, x.sin())
y, vjp_fn = vjp(lambda x: vjp_fn(x)[0], x)
self.assertEqual(y, x * x.cos())
y = vjp_fn(x)[0]
# Honestly IDK what the result here is... but at least it runs
if __name__ == '__main__':
run_tests()

2516
functorch/test/test_vmap.py Normal file

File diff suppressed because it is too large