Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 12:54:11 +08:00

Commit: [functorch] a lot of files
functorch/.gitignore (vendored, new file, 4 lines)
@@ -0,0 +1,4 @@
build/
dist/
functorch.egg-info/
*__pycache__*
functorch/examples/.gitignore (vendored, new file, 1 line)
@@ -0,0 +1 @@
cifar10/
functorch/examples/dp_cifar10/.gdbinit (new file, 2 lines)
@@ -0,0 +1,2 @@
catch throw
r cifar10_transforms.py
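The two lines above are gdb commands: `catch throw` sets a catchpoint that stops execution whenever a C++ exception is thrown (e.g. inside libtorch), and `r cifar10_transforms.py` runs the debuggee with the transforms script as its argument. The file is presumably meant to be picked up when launching `gdb python` from this directory, so that C++ errors raised while the prototype script runs can be inspected at the throw site.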
functorch/examples/dp_cifar10/cifar10_expandweights.py (new file, 436 lines)
@@ -0,0 +1,436 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

"""
Runs CIFAR10 training with differential privacy.
"""

import argparse
import os
import shutil

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data
import torch.utils.data.distributed
import torch.utils.tensorboard as tensorboard
import torchvision.models as models
import torchvision.transforms as transforms
from opacus import PrivacyEngine
from opacus.utils import stats
from opacus.utils.module_modification import convert_batchnorm_modules
from torchvision.datasets import CIFAR10
from tqdm import tqdm

# This is based off of zou3519/pytorch:expand_weights.


def save_checkpoint(state, is_best, filename="checkpoint.tar"):
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, "model_best.pth.tar")


def accuracy(preds, labels):
    return (preds == labels).mean()


def compute_norms(sample_grads):
    batch_size = sample_grads[0].shape[0]
    norms = [sample_grad.view(batch_size, -1).norm(2, dim=-1) for sample_grad in sample_grads]
    norms = torch.stack(norms, dim=0).norm(2, dim=0)
    return norms


def clip_and_accumulate_and_add_noise(model, max_per_sample_grad_norm=1.0, noise_multiplier=1.0):
    sample_grads = tuple(param.grad_sample for param in model.parameters())

    # step 0: compute the norms
    sample_norms = compute_norms(sample_grads)

    # step 1: compute clipping factors
    clip_factor = max_per_sample_grad_norm / (sample_norms + 1e-6)
    clip_factor = clip_factor.clamp(max=1.0)

    # step 2: clip
    grads = tuple(torch.einsum('i,i...', clip_factor, sample_grad)
                  for sample_grad in sample_grads)

    # step 3: add gaussian noise
    stddev = max_per_sample_grad_norm * noise_multiplier
    noises = tuple(torch.normal(0, stddev, grad_param.shape, device=grad_param.device)
                   for grad_param in grads)
    grads = tuple(noise + grad_param for noise, grad_param in zip(noises, grads))

    # step 4: assign the new grads, delete the sample grads
    for param, param_grad in zip(model.parameters(), grads):
        param.grad = param_grad
        del param.grad_sample


def train(args, model, train_loader, optimizer, epoch, device):
    model.train()
    criterion = nn.CrossEntropyLoss()

    losses = []
    top1_acc = []

    for i, (images, target) in enumerate(tqdm(train_loader)):

        images = images.to(device)
        target = target.to(device)

        # Step 1: compute per-sample-grads (provided by pytorch)
        # loss.backward() populates the grad_sample attribute of each param
        with model.compute_per_sample_grads(batch_size=images.shape[0]):
            output = model(images)
            loss = criterion(output, target)
            loss.backward()

        # Step 2: Clip the per-sample-grads, sum them to form grads, and add noise.
        # Opacus implements this but I wrote a custom one to show how this would work.
        # This deletes the grad_sample attributes and populates the grad attributes.
        clip_and_accumulate_and_add_noise(
            model, args.max_per_sample_grad_norm, args.sigma)

        preds = np.argmax(output.detach().cpu().numpy(), axis=1)
        labels = target.detach().cpu().numpy()
        losses.append(loss.item())

        # measure accuracy and record loss
        acc1 = accuracy(preds, labels)

        top1_acc.append(acc1)
        stats.update(stats.StatType.TRAIN, acc1=acc1)

        # make sure we take a step after processing the last mini-batch in the
        # epoch to ensure we start the next epoch with a clean state
        if ((i + 1) % args.n_accumulation_steps == 0) or ((i + 1) == len(train_loader)):
            optimizer.step()
            optimizer.zero_grad()
        else:
            optimizer.virtual_step()

        if i % args.print_freq == 0:
            print(
                f"\tTrain Epoch: {epoch} \t"
                f"Loss: {np.mean(losses):.6f} "
                f"Acc@1: {np.mean(top1_acc):.6f} "
            )


def test(args, model, test_loader, device):
    model.eval()
    criterion = nn.CrossEntropyLoss()
    losses = []
    top1_acc = []

    with torch.no_grad():
        for images, target in tqdm(test_loader):
            images = images.to(device)
            target = target.to(device)

            output = model(images)
            loss = criterion(output, target)
            preds = np.argmax(output.detach().cpu().numpy(), axis=1)
            labels = target.detach().cpu().numpy()
            acc1 = accuracy(preds, labels)

            losses.append(loss.item())
            top1_acc.append(acc1)

    top1_avg = np.mean(top1_acc)
    stats.update(stats.StatType.TEST, acc1=top1_avg)

    print(f"\tTest set: " f"Loss: {np.mean(losses):.6f} " f"Acc@1: {top1_avg:.6f} ")
    return np.mean(top1_acc)


def main():
    parser = argparse.ArgumentParser(description="PyTorch CIFAR10 DP Training")
    parser.add_argument(
        "-j",
        "--workers",
        default=2,
        type=int,
        metavar="N",
        help="number of data loading workers (default: 2)",
    )
    parser.add_argument(
        "--epochs",
        default=90,
        type=int,
        metavar="N",
        help="number of total epochs to run",
    )
    parser.add_argument(
        "--start-epoch",
        default=1,
        type=int,
        metavar="N",
        help="manual epoch number (useful on restarts)",
    )
    parser.add_argument(
        "-b",
        "--batch-size",
        # This should be 256, but that OOMs using the prototype.
        default=64,
        type=int,
        metavar="N",
        help="mini-batch size (default: 64), this is the total "
        "batch size of all GPUs on the current node when "
        "using Data Parallel or Distributed Data Parallel",
    )
    parser.add_argument(
        "-na",
        "--n_accumulation_steps",
        default=1,
        type=int,
        metavar="N",
        help="number of mini-batches to accumulate into an effective batch",
    )
    parser.add_argument(
        "--lr",
        "--learning-rate",
        default=0.001,
        type=float,
        metavar="LR",
        help="initial learning rate",
        dest="lr",
    )
    parser.add_argument(
        "--momentum", default=0.9, type=float, metavar="M", help="SGD momentum"
    )
    parser.add_argument(
        "--wd",
        "--weight-decay",
        default=5e-4,
        type=float,
        metavar="W",
        help="SGD weight decay (default: 5e-4)",
        dest="weight_decay",
    )
    parser.add_argument(
        "-p",
        "--print-freq",
        default=10,
        type=int,
        metavar="N",
        help="print frequency (default: 10)",
    )
    parser.add_argument(
        "--resume",
        default="",
        type=str,
        metavar="PATH",
        help="path to latest checkpoint (default: none)",
    )
    parser.add_argument(
        "-e",
        "--evaluate",
        dest="evaluate",
        action="store_true",
        help="evaluate model on validation set",
    )
    parser.add_argument(
        "--seed", default=None, type=int, help="seed for initializing training"
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cuda",
        help="GPU ID for this process (default: 'cuda')",
    )
    parser.add_argument(
        "--sigma",
        type=float,
        default=1.0,
        metavar="S",
        help="Noise multiplier (default 1.0)",
    )
    parser.add_argument(
        "-c",
        "--max-per-sample-grad_norm",
        type=float,
        default=1.0,
        metavar="C",
        help="Clip per-sample gradients to this norm (default 1.0)",
    )
    parser.add_argument(
        "--disable-dp",
        action="store_true",
        default=False,
        help="Disable privacy training and just train with vanilla SGD",
    )
    parser.add_argument(
        "--secure-rng",
        action="store_true",
        default=False,
        help="Enable Secure RNG to have trustworthy privacy guarantees. Comes at a performance cost",
    )
    parser.add_argument(
        "--delta",
        type=float,
        default=1e-5,
        metavar="D",
        help="Target delta (default: 1e-5)",
    )

    parser.add_argument(
        "--checkpoint-file",
        type=str,
        default="checkpoint",
        help="path to save checkpoints",
    )
    parser.add_argument(
        "--data-root",
        type=str,
        default="../cifar10",
        help="Where CIFAR10 is/will be stored",
    )
    parser.add_argument(
        "--log-dir", type=str, default="", help="Where Tensorboard log will be stored"
    )
    parser.add_argument(
        "--optim",
        type=str,
        default="Adam",
        help="Optimizer to use (Adam, RMSprop, SGD)",
    )

    args = parser.parse_args()
    # NB: DP is forced off in this prototype script.
    args.disable_dp = True

    if args.disable_dp and args.n_accumulation_steps > 1:
        raise ValueError("Virtual steps only work with DP enabled")

    # The following few lines enable stats gathering about the run
    # 1. where the stats should be logged
    stats.set_global_summary_writer(
        tensorboard.SummaryWriter(os.path.join("/tmp/stat", args.log_dir))
    )
    # 2. enable stats
    stats.add(
        # stats about gradient norms aggregated for all layers
        stats.Stat(stats.StatType.GRAD, "AllLayers", frequency=0.1),
        # stats about gradient norms per layer
        stats.Stat(stats.StatType.GRAD, "PerLayer", frequency=0.1),
        # stats about clipping
        stats.Stat(stats.StatType.GRAD, "ClippingStats", frequency=0.1),
        # stats on training accuracy
        stats.Stat(stats.StatType.TRAIN, "accuracy", frequency=0.01),
        # stats on validation accuracy
        stats.Stat(stats.StatType.TEST, "accuracy"),
    )

    # The following lines enable stat gathering for the clipping process
    # and set a default of per layer clipping for the Privacy Engine
    clipping = {"clip_per_layer": False, "enable_stat": True}

    if args.secure_rng:
        # NB: secure RNG is not supported by this prototype.
        assert False
        try:
            import torchcsprng as prng
        except ImportError as e:
            msg = (
                "To use secure RNG, you must install the torchcsprng package! "
                "Check out the instructions here: https://github.com/pytorch/csprng#installation"
            )
            raise ImportError(msg) from e

        generator = prng.create_random_device_generator("/dev/urandom")

    else:
        generator = None

    augmentations = [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
    ]
    normalize = [
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ]
    train_transform = transforms.Compose(
        augmentations + normalize if args.disable_dp else normalize
    )

    test_transform = transforms.Compose(normalize)

    train_dataset = CIFAR10(
        root=args.data_root, train=True, download=True, transform=train_transform
    )

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.workers,
        drop_last=True,
        generator=generator,
    )

    test_dataset = CIFAR10(
        root=args.data_root, train=False, download=True, transform=test_transform
    )
    test_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.workers,
    )

    best_acc1 = 0
    device = torch.device(args.device)
    model = convert_batchnorm_modules(models.resnet18(num_classes=10))
    # model = CIFAR10Model()
    model = model.to(device)

    if args.optim == "SGD":
        optimizer = optim.SGD(
            model.parameters(),
            lr=args.lr,
            momentum=args.momentum,
            weight_decay=args.weight_decay,
        )
    elif args.optim == "RMSprop":
        optimizer = optim.RMSprop(model.parameters(), lr=args.lr)
    elif args.optim == "Adam":
        optimizer = optim.Adam(model.parameters(), lr=args.lr)
    else:
        raise NotImplementedError("Optimizer not recognized. Please check spelling")

    if not args.disable_dp:
        privacy_engine = PrivacyEngine(
            model,
            batch_size=args.batch_size * args.n_accumulation_steps,
            sample_size=len(train_dataset),
            alphas=[1 + x / 10.0 for x in range(1, 100)] + list(range(12, 64)),
            noise_multiplier=args.sigma,
            max_grad_norm=args.max_per_sample_grad_norm,
            secure_rng=args.secure_rng,
            **clipping,
        )
        privacy_engine.attach(optimizer)

    for epoch in range(args.start_epoch, args.epochs + 1):
        train(args, model, train_loader, optimizer, epoch, device)
        top1_acc = test(args, model, test_loader, device)

        # remember best acc@1 and save checkpoint
        is_best = top1_acc > best_acc1
        best_acc1 = max(top1_acc, best_acc1)

        save_checkpoint(
            {
                "epoch": epoch + 1,
                "arch": "ResNet18",
                "state_dict": model.state_dict(),
                "best_acc1": best_acc1,
                "optimizer": optimizer.state_dict(),
            },
            is_best,
            filename=args.checkpoint_file + ".tar",
        )


if __name__ == "__main__":
    main()
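The clipping arithmetic in `clip_and_accumulate_and_add_noise` is easy to check in isolation. Below is a minimal, self-contained sketch (toy shapes, no model; the `grad_sample` plumbing is replaced by a hand-built tuple of tensors, and all names are hypothetical) that walks the same steps and asserts that every per-sample contribution ends up with norm at most C:

import torch

torch.manual_seed(0)
C = 1.0          # max per-sample grad norm
sigma = 1.0      # noise multiplier
batch = 4

# toy per-sample gradients for two "parameters"
sample_grads = (torch.randn(batch, 3, 3), torch.randn(batch, 5))

# step 0: per-sample norm across all parameters
norms = torch.stack(
    [g.view(batch, -1).norm(2, dim=-1) for g in sample_grads]
).norm(2, dim=0)

# step 1: clipping factors, capped at 1 so small grads pass through unchanged
clip_factor = (C / (norms + 1e-6)).clamp(max=1.0)

# step 2: scale each sample's grad by its factor and sum over the batch
grads = tuple(torch.einsum('i,i...', clip_factor, g) for g in sample_grads)

# step 3: add gaussian noise with stddev C * sigma
grads = tuple(g + torch.normal(0, C * sigma, g.shape) for g in grads)

# sanity check: after clipping (before summing), every per-sample norm <= C
clipped = [clip_factor.view(-1, *([1] * (g.dim() - 1))) * g for g in sample_grads]
clipped_norms = torch.stack(
    [g.view(batch, -1).norm(2, dim=-1) for g in clipped]
).norm(2, dim=0)
assert (clipped_norms <= C + 1e-4).all()

Note the einsum `'i,i...'` both scales each sample's gradient and contracts away the batch dimension, so step 2 performs the clip and the accumulate in one call.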
functorch/examples/dp_cifar10/cifar10_opacus.py (new file, 405 lines)
@@ -0,0 +1,405 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

"""
Runs CIFAR10 training with differential privacy.
"""

import argparse
import os
import shutil

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data
import torch.utils.data.distributed
import torch.utils.tensorboard as tensorboard
import torchvision.models as models
import torchvision.transforms as transforms
from opacus import PrivacyEngine
from opacus.utils import stats
from opacus.utils.module_modification import convert_batchnorm_modules
from torchvision.datasets import CIFAR10
from tqdm import tqdm


def save_checkpoint(state, is_best, filename="checkpoint.tar"):
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, "model_best.pth.tar")


def accuracy(preds, labels):
    return (preds == labels).mean()


def train(args, model, train_loader, optimizer, epoch, device):
    model.train()
    criterion = nn.CrossEntropyLoss()

    losses = []
    top1_acc = []

    for i, (images, target) in enumerate(tqdm(train_loader)):

        images = images.to(device)
        target = target.to(device)

        # compute output
        output = model(images)
        loss = criterion(output, target)
        preds = np.argmax(output.detach().cpu().numpy(), axis=1)
        labels = target.detach().cpu().numpy()

        # measure accuracy and record loss
        acc1 = accuracy(preds, labels)

        losses.append(loss.item())
        top1_acc.append(acc1)
        stats.update(stats.StatType.TRAIN, acc1=acc1)

        # compute gradient and do SGD step
        loss.backward()

        # make sure we take a step after processing the last mini-batch in the
        # epoch to ensure we start the next epoch with a clean state
        if ((i + 1) % args.n_accumulation_steps == 0) or ((i + 1) == len(train_loader)):
            optimizer.step()
            optimizer.zero_grad()
        else:
            optimizer.virtual_step()

        if i % args.print_freq == 0:
            if not args.disable_dp:
                epsilon, best_alpha = optimizer.privacy_engine.get_privacy_spent(
                    args.delta
                )
                print(
                    f"\tTrain Epoch: {epoch} \t"
                    f"Loss: {np.mean(losses):.6f} "
                    f"Acc@1: {np.mean(top1_acc):.6f} "
                    f"(ε = {epsilon:.2f}, δ = {args.delta}) for α = {best_alpha}"
                )
            else:
                print(
                    f"\tTrain Epoch: {epoch} \t"
                    f"Loss: {np.mean(losses):.6f} "
                    f"Acc@1: {np.mean(top1_acc):.6f} "
                )


def test(args, model, test_loader, device):
    model.eval()
    criterion = nn.CrossEntropyLoss()
    losses = []
    top1_acc = []

    with torch.no_grad():
        for images, target in tqdm(test_loader):
            images = images.to(device)
            target = target.to(device)

            output = model(images)
            loss = criterion(output, target)
            preds = np.argmax(output.detach().cpu().numpy(), axis=1)
            labels = target.detach().cpu().numpy()
            acc1 = accuracy(preds, labels)

            losses.append(loss.item())
            top1_acc.append(acc1)

    top1_avg = np.mean(top1_acc)
    stats.update(stats.StatType.TEST, acc1=top1_avg)

    print(f"\tTest set: " f"Loss: {np.mean(losses):.6f} " f"Acc@1: {top1_avg:.6f} ")
    return np.mean(top1_acc)


def main():
    parser = argparse.ArgumentParser(description="PyTorch CIFAR10 DP Training")
    parser.add_argument(
        "-j",
        "--workers",
        default=2,
        type=int,
        metavar="N",
        help="number of data loading workers (default: 2)",
    )
    parser.add_argument(
        "--epochs",
        default=90,
        type=int,
        metavar="N",
        help="number of total epochs to run",
    )
    parser.add_argument(
        "--start-epoch",
        default=1,
        type=int,
        metavar="N",
        help="manual epoch number (useful on restarts)",
    )
    parser.add_argument(
        "-b",
        "--batch-size",
        default=256,
        type=int,
        metavar="N",
        help="mini-batch size (default: 256), this is the total "
        "batch size of all GPUs on the current node when "
        "using Data Parallel or Distributed Data Parallel",
    )
    parser.add_argument(
        "-na",
        "--n_accumulation_steps",
        default=1,
        type=int,
        metavar="N",
        help="number of mini-batches to accumulate into an effective batch",
    )
    parser.add_argument(
        "--lr",
        "--learning-rate",
        default=0.001,
        type=float,
        metavar="LR",
        help="initial learning rate",
        dest="lr",
    )
    parser.add_argument(
        "--momentum", default=0.9, type=float, metavar="M", help="SGD momentum"
    )
    parser.add_argument(
        "--wd",
        "--weight-decay",
        default=5e-4,
        type=float,
        metavar="W",
        help="SGD weight decay (default: 5e-4)",
        dest="weight_decay",
    )
    parser.add_argument(
        "-p",
        "--print-freq",
        default=10,
        type=int,
        metavar="N",
        help="print frequency (default: 10)",
    )
    parser.add_argument(
        "--resume",
        default="",
        type=str,
        metavar="PATH",
        help="path to latest checkpoint (default: none)",
    )
    parser.add_argument(
        "-e",
        "--evaluate",
        dest="evaluate",
        action="store_true",
        help="evaluate model on validation set",
    )
    parser.add_argument(
        "--seed", default=None, type=int, help="seed for initializing training"
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cuda",
        help="GPU ID for this process (default: 'cuda')",
    )
    parser.add_argument(
        "--sigma",
        type=float,
        default=1.0,
        metavar="S",
        help="Noise multiplier (default 1.0)",
    )
    parser.add_argument(
        "-c",
        "--max-per-sample-grad_norm",
        type=float,
        default=1.0,
        metavar="C",
        help="Clip per-sample gradients to this norm (default 1.0)",
    )
    parser.add_argument(
        "--disable-dp",
        action="store_true",
        default=False,
        help="Disable privacy training and just train with vanilla SGD",
    )
    parser.add_argument(
        "--secure-rng",
        action="store_true",
        default=False,
        help="Enable Secure RNG to have trustworthy privacy guarantees. Comes at a performance cost",
    )
    parser.add_argument(
        "--delta",
        type=float,
        default=1e-5,
        metavar="D",
        help="Target delta (default: 1e-5)",
    )

    parser.add_argument(
        "--checkpoint-file",
        type=str,
        default="checkpoint",
        help="path to save checkpoints",
    )
    parser.add_argument(
        "--data-root",
        type=str,
        default="../cifar10",
        help="Where CIFAR10 is/will be stored",
    )
    parser.add_argument(
        "--log-dir", type=str, default="", help="Where Tensorboard log will be stored"
    )
    parser.add_argument(
        "--optim",
        type=str,
        default="Adam",
        help="Optimizer to use (Adam, RMSprop, SGD)",
    )

    args = parser.parse_args()

    if args.disable_dp and args.n_accumulation_steps > 1:
        raise ValueError("Virtual steps only work with DP enabled")

    # The following few lines enable stats gathering about the run
    # 1. where the stats should be logged
    stats.set_global_summary_writer(
        tensorboard.SummaryWriter(os.path.join("/tmp/stat", args.log_dir))
    )
    # 2. enable stats
    stats.add(
        # stats about gradient norms aggregated for all layers
        stats.Stat(stats.StatType.GRAD, "AllLayers", frequency=0.1),
        # stats about gradient norms per layer
        stats.Stat(stats.StatType.GRAD, "PerLayer", frequency=0.1),
        # stats about clipping
        stats.Stat(stats.StatType.GRAD, "ClippingStats", frequency=0.1),
        # stats on training accuracy
        stats.Stat(stats.StatType.TRAIN, "accuracy", frequency=0.01),
        # stats on validation accuracy
        stats.Stat(stats.StatType.TEST, "accuracy"),
    )

    # The following lines enable stat gathering for the clipping process
    # and set a default of per layer clipping for the Privacy Engine
    clipping = {"clip_per_layer": False, "enable_stat": True}

    if args.secure_rng:
        try:
            import torchcsprng as prng
        except ImportError as e:
            msg = (
                "To use secure RNG, you must install the torchcsprng package! "
                "Check out the instructions here: https://github.com/pytorch/csprng#installation"
            )
            raise ImportError(msg) from e

        generator = prng.create_random_device_generator("/dev/urandom")

    else:
        generator = None

    augmentations = [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
    ]
    normalize = [
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ]
    train_transform = transforms.Compose(
        augmentations + normalize if args.disable_dp else normalize
    )

    test_transform = transforms.Compose(normalize)

    train_dataset = CIFAR10(
        root=args.data_root, train=True, download=True, transform=train_transform
    )

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.workers,
        drop_last=True,
        generator=generator,
    )

    test_dataset = CIFAR10(
        root=args.data_root, train=False, download=True, transform=test_transform
    )
    test_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.workers,
    )

    best_acc1 = 0
    device = torch.device(args.device)
    model = convert_batchnorm_modules(models.resnet18(num_classes=10))
    model = model.to(device)

    if args.optim == "SGD":
        optimizer = optim.SGD(
            model.parameters(),
            lr=args.lr,
            momentum=args.momentum,
            weight_decay=args.weight_decay,
        )
    elif args.optim == "RMSprop":
        optimizer = optim.RMSprop(model.parameters(), lr=args.lr)
    elif args.optim == "Adam":
        optimizer = optim.Adam(model.parameters(), lr=args.lr)
    else:
        raise NotImplementedError("Optimizer not recognized. Please check spelling")

    if not args.disable_dp:
        privacy_engine = PrivacyEngine(
            model,
            batch_size=args.batch_size * args.n_accumulation_steps,
            sample_size=len(train_dataset),
            alphas=[1 + x / 10.0 for x in range(1, 100)] + list(range(12, 64)),
            noise_multiplier=args.sigma,
            max_grad_norm=args.max_per_sample_grad_norm,
            secure_rng=args.secure_rng,
            **clipping,
        )
        privacy_engine.attach(optimizer)

    for epoch in range(args.start_epoch, args.epochs + 1):
        train(args, model, train_loader, optimizer, epoch, device)
        top1_acc = test(args, model, test_loader, device)

        # remember best acc@1 and save checkpoint
        is_best = top1_acc > best_acc1
        best_acc1 = max(top1_acc, best_acc1)

        save_checkpoint(
            {
                "epoch": epoch + 1,
                "arch": "ResNet18",
                "state_dict": model.state_dict(),
                "best_acc1": best_acc1,
                "optimizer": optimizer.state_dict(),
            },
            is_best,
            filename=args.checkpoint_file + ".tar",
        )


if __name__ == "__main__":
    main()
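One detail shared by all of these scripts is the accumulation logic in the training loop: the optimizer only steps every `n_accumulation_steps` mini-batches (and on the final batch of the epoch), and under DP the skipped iterations call `optimizer.virtual_step()` from the 0.x-era Opacus API used here, so that per-sample clipping happens per mini-batch while the noisy step covers the whole effective batch. A minimal sketch of just the stepping pattern with plain SGD (no DP; `n_accum` and `n_batches` are made-up sizes to exercise the branch) looks like:

import torch

n_accum, n_batches = 4, 10

model = torch.nn.Linear(8, 2)
opt = torch.optim.SGD(model.parameters(), lr=0.1)

for i in range(n_batches):
    x, y = torch.randn(16, 8), torch.randint(0, 2, (16,))
    loss = torch.nn.functional.cross_entropy(model(x), y)
    loss.backward()  # grads accumulate across iterations by default

    # Step on every n_accum-th batch, and on the last batch of the epoch,
    # so the next epoch starts from a clean state.
    if ((i + 1) % n_accum == 0) or ((i + 1) == n_batches):
        opt.step()
        opt.zero_grad()
    # (under DP, the else-branch would call optimizer.virtual_step())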
functorch/examples/dp_cifar10/cifar10_transforms.py (new file, 475 lines)
@@ -0,0 +1,475 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

"""
Runs CIFAR10 training with differential privacy.
"""

import argparse
import os
import shutil

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data
import torch.utils.data.distributed
import torch.utils.tensorboard as tensorboard
import torchvision.models as models
import torchvision.transforms as transforms
from opacus import PrivacyEngine
from opacus.utils import stats
from opacus.utils.module_modification import convert_batchnorm_modules
from torchvision.datasets import CIFAR10
from tqdm import tqdm

from make_functional import make_functional, load_weights
from torch.eager_transforms import vmap, grad_with_value
from functools import partial
# from resnet import resnet18


def save_checkpoint(state, is_best, filename="checkpoint.tar"):
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, "model_best.pth.tar")


def accuracy(preds, labels):
    return (preds == labels).mean()


def compute_norms(sample_grads):
    batch_size = sample_grads[0].shape[0]
    norms = [sample_grad.view(batch_size, -1).norm(2, dim=-1) for sample_grad in sample_grads]
    norms = torch.stack(norms, dim=0).norm(2, dim=0)
    return norms


def clip_and_accumulate_and_add_noise(model, max_per_sample_grad_norm=1.0, noise_multiplier=1.0):
    sample_grads = tuple(param.grad_sample for param in model.parameters())

    # step 0: compute the norms
    sample_norms = compute_norms(sample_grads)

    # step 1: compute clipping factors
    clip_factor = max_per_sample_grad_norm / (sample_norms + 1e-6)
    clip_factor = clip_factor.clamp(max=1.0)

    # step 2: clip
    grads = tuple(torch.einsum('i,i...', clip_factor, sample_grad)
                  for sample_grad in sample_grads)

    # step 3: add gaussian noise
    stddev = max_per_sample_grad_norm * noise_multiplier
    noises = tuple(torch.normal(0, stddev, grad_param.shape, device=grad_param.device)
                   for grad_param in grads)
    grads = tuple(noise + grad_param for noise, grad_param in zip(noises, grads))

    # step 4: assign the new grads, delete the sample grads
    for param, param_grad in zip(model.parameters(), grads):
        param.grad = param_grad
        del param.grad_sample


def train(args, model, train_loader, optimizer, epoch, device):
    model.train()
    criterion = nn.CrossEntropyLoss()

    losses = []
    top1_acc = []

    for i, (images, target) in enumerate(tqdm(train_loader)):

        images = images.to(device)
        target = target.to(device)

        # Step 1: compute per-sample-grads

        # In order to use functional vmap+grad, we need to be able to
        # pass the weights to a model.
        weights, func_model, descriptors = make_functional(model)

        # To use vmap+grad to compute per-sample-grads, the forward pass
        # must be re-formulated on a single example.
        # We use the `grad` operator to compute forward+backward on a single example,
        # and finally `vmap` to do forward+backward on multiple examples.
        def compute_loss_and_output(weights, image, target):
            images = image.unsqueeze(0)
            targets = target.unsqueeze(0)
            output = func_model(weights, (images,))
            loss = criterion(output, targets)
            return loss, output.squeeze(0)

        # `grad(f)` is a functional API that returns a function `f'` that
        # computes gradients by running both the forward and backward pass.
        # We want to extract some intermediate
        # values from the computation (i.e. the loss and output).
        #
        # To extract the loss, we use the `grad_with_value` API, which returns
        # the gradient of the weights w.r.t. the loss as well as the loss itself.
        #
        # To extract the output, we use the `has_aux=True` flag.
        # `has_aux=True` assumes that `f` returns a tuple of two values,
        # where the first is to be differentiated and the second "auxiliary value"
        # is not to be differentiated. `f'` returns the gradient w.r.t. the loss,
        # the loss, and the auxiliary value.
        grads_loss_output = grad_with_value(compute_loss_and_output, has_aux=True)

        def packed(weights, images, target):
            grads, loss, output = grads_loss_output(weights, images, target)
            result = tuple([*grads, loss, output])
            return result

        result = vmap(partial(packed, weights))(images, target)
        sample_grads, sample_loss, output = result[:-2], result[-2], result[-1]
        loss = sample_loss.mean()

        # `load_weights` is the inverse operation of make_functional. We put
        # things back into a model so that they're easier to manipulate
        load_weights(model, descriptors, weights)
        for grad_sample, weight in zip(sample_grads, model.parameters()):
            weight.grad_sample = grad_sample.detach()

        # Step 2: Clip the per-sample-grads, sum them to form grads, and add noise.
        # This deletes the grad_sample attributes and populates the grad attributes.
        clip_and_accumulate_and_add_noise(
            model, args.max_per_sample_grad_norm, args.sigma)

        preds = np.argmax(output.detach().cpu().numpy(), axis=1)
        labels = target.detach().cpu().numpy()
        losses.append(loss.item())

        # measure accuracy and record loss
        acc1 = accuracy(preds, labels)

        top1_acc.append(acc1)
        stats.update(stats.StatType.TRAIN, acc1=acc1)

        # make sure we take a step after processing the last mini-batch in the
        # epoch to ensure we start the next epoch with a clean state
        if ((i + 1) % args.n_accumulation_steps == 0) or ((i + 1) == len(train_loader)):
            optimizer.step()
            optimizer.zero_grad()
        else:
            optimizer.virtual_step()

        if i % args.print_freq == 0:
            print(
                f"\tTrain Epoch: {epoch} \t"
                f"Loss: {np.mean(losses):.6f} "
                f"Acc@1: {np.mean(top1_acc):.6f} "
            )


def test(args, model, test_loader, device):
    model.eval()
    criterion = nn.CrossEntropyLoss()
    losses = []
    top1_acc = []

    with torch.no_grad():
        for images, target in tqdm(test_loader):
            images = images.to(device)
            target = target.to(device)

            output = model(images)
            loss = criterion(output, target)
            preds = np.argmax(output.detach().cpu().numpy(), axis=1)
            labels = target.detach().cpu().numpy()
            acc1 = accuracy(preds, labels)

            losses.append(loss.item())
            top1_acc.append(acc1)

    top1_avg = np.mean(top1_acc)
    stats.update(stats.StatType.TEST, acc1=top1_avg)

    print(f"\tTest set: " f"Loss: {np.mean(losses):.6f} " f"Acc@1: {top1_avg:.6f} ")
    return np.mean(top1_acc)


def main():
    parser = argparse.ArgumentParser(description="PyTorch CIFAR10 DP Training")
    parser.add_argument(
        "-j",
        "--workers",
        default=2,
        type=int,
        metavar="N",
        help="number of data loading workers (default: 2)",
    )
    parser.add_argument(
        "--epochs",
        default=90,
        type=int,
        metavar="N",
        help="number of total epochs to run",
    )
    parser.add_argument(
        "--start-epoch",
        default=1,
        type=int,
        metavar="N",
        help="manual epoch number (useful on restarts)",
    )
    parser.add_argument(
        "-b",
        "--batch-size",
        # This should be 256, but that OOMs using the prototype.
        default=64,
        type=int,
        metavar="N",
        help="mini-batch size (default: 64), this is the total "
        "batch size of all GPUs on the current node when "
        "using Data Parallel or Distributed Data Parallel",
    )
    parser.add_argument(
        "-na",
        "--n_accumulation_steps",
        default=1,
        type=int,
        metavar="N",
        help="number of mini-batches to accumulate into an effective batch",
    )
    parser.add_argument(
        "--lr",
        "--learning-rate",
        default=0.001,
        type=float,
        metavar="LR",
        help="initial learning rate",
        dest="lr",
    )
    parser.add_argument(
        "--momentum", default=0.9, type=float, metavar="M", help="SGD momentum"
    )
    parser.add_argument(
        "--wd",
        "--weight-decay",
        default=5e-4,
        type=float,
        metavar="W",
        help="SGD weight decay (default: 5e-4)",
        dest="weight_decay",
    )
    parser.add_argument(
        "-p",
        "--print-freq",
        default=10,
        type=int,
        metavar="N",
        help="print frequency (default: 10)",
    )
    parser.add_argument(
        "--resume",
        default="",
        type=str,
        metavar="PATH",
        help="path to latest checkpoint (default: none)",
    )
    parser.add_argument(
        "-e",
        "--evaluate",
        dest="evaluate",
        action="store_true",
        help="evaluate model on validation set",
    )
    parser.add_argument(
        "--seed", default=None, type=int, help="seed for initializing training"
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cuda",
        help="GPU ID for this process (default: 'cuda')",
    )
    parser.add_argument(
        "--sigma",
        type=float,
        default=1.0,
        metavar="S",
        help="Noise multiplier (default 1.0)",
    )
    parser.add_argument(
        "-c",
        "--max-per-sample-grad_norm",
        type=float,
        default=1.0,
        metavar="C",
        help="Clip per-sample gradients to this norm (default 1.0)",
    )
    parser.add_argument(
        "--disable-dp",
        action="store_true",
        default=False,
        help="Disable privacy training and just train with vanilla SGD",
    )
    parser.add_argument(
        "--secure-rng",
        action="store_true",
        default=False,
        help="Enable Secure RNG to have trustworthy privacy guarantees. Comes at a performance cost",
    )
    parser.add_argument(
        "--delta",
        type=float,
        default=1e-5,
        metavar="D",
        help="Target delta (default: 1e-5)",
    )

    parser.add_argument(
        "--checkpoint-file",
        type=str,
        default="checkpoint",
        help="path to save checkpoints",
    )
    parser.add_argument(
        "--data-root",
        type=str,
        default="../cifar10",
        help="Where CIFAR10 is/will be stored",
    )
    parser.add_argument(
        "--log-dir", type=str, default="", help="Where Tensorboard log will be stored"
    )
    parser.add_argument(
        "--optim",
        type=str,
        default="Adam",
        help="Optimizer to use (Adam, RMSprop, SGD)",
    )

    args = parser.parse_args()
    # NB: DP is forced off in this prototype script.
    args.disable_dp = True

    if args.disable_dp and args.n_accumulation_steps > 1:
        raise ValueError("Virtual steps only work with DP enabled")

    # The following few lines enable stats gathering about the run
    # 1. where the stats should be logged
    stats.set_global_summary_writer(
        tensorboard.SummaryWriter(os.path.join("/tmp/stat", args.log_dir))
    )
    # 2. enable stats
    stats.add(
        # stats about gradient norms aggregated for all layers
        stats.Stat(stats.StatType.GRAD, "AllLayers", frequency=0.1),
        # stats about gradient norms per layer
        stats.Stat(stats.StatType.GRAD, "PerLayer", frequency=0.1),
        # stats about clipping
        stats.Stat(stats.StatType.GRAD, "ClippingStats", frequency=0.1),
        # stats on training accuracy
        stats.Stat(stats.StatType.TRAIN, "accuracy", frequency=0.01),
        # stats on validation accuracy
        stats.Stat(stats.StatType.TEST, "accuracy"),
    )

    # The following lines enable stat gathering for the clipping process
    # and set a default of per layer clipping for the Privacy Engine
    clipping = {"clip_per_layer": False, "enable_stat": True}

    if args.secure_rng:
        # NB: secure RNG is not supported by this prototype.
        assert False
        try:
            import torchcsprng as prng
        except ImportError as e:
            msg = (
                "To use secure RNG, you must install the torchcsprng package! "
                "Check out the instructions here: https://github.com/pytorch/csprng#installation"
            )
            raise ImportError(msg) from e

        generator = prng.create_random_device_generator("/dev/urandom")

    else:
        generator = None

    augmentations = [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
    ]
    normalize = [
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ]
    train_transform = transforms.Compose(
        augmentations + normalize if args.disable_dp else normalize
    )

    test_transform = transforms.Compose(normalize)

    train_dataset = CIFAR10(
        root=args.data_root, train=True, download=True, transform=train_transform
    )

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.workers,
        drop_last=True,
        generator=generator,
    )

    test_dataset = CIFAR10(
        root=args.data_root, train=False, download=True, transform=test_transform
    )
    test_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.workers,
    )

    best_acc1 = 0
    device = torch.device(args.device)
    model = convert_batchnorm_modules(models.resnet18(num_classes=10))
    # model = CIFAR10Model()
    model = model.to(device)

    if args.optim == "SGD":
        optimizer = optim.SGD(
            model.parameters(),
            lr=args.lr,
            momentum=args.momentum,
            weight_decay=args.weight_decay,
        )
    elif args.optim == "RMSprop":
        optimizer = optim.RMSprop(model.parameters(), lr=args.lr)
    elif args.optim == "Adam":
        optimizer = optim.Adam(model.parameters(), lr=args.lr)
    else:
        raise NotImplementedError("Optimizer not recognized. Please check spelling")

    if not args.disable_dp:
        privacy_engine = PrivacyEngine(
            model,
            batch_size=args.batch_size * args.n_accumulation_steps,
            sample_size=len(train_dataset),
            alphas=[1 + x / 10.0 for x in range(1, 100)] + list(range(12, 64)),
            noise_multiplier=args.sigma,
            max_grad_norm=args.max_per_sample_grad_norm,
            secure_rng=args.secure_rng,
            **clipping,
        )
        privacy_engine.attach(optimizer)

    for epoch in range(args.start_epoch, args.epochs + 1):
        train(args, model, train_loader, optimizer, epoch, device)
        top1_acc = test(args, model, test_loader, device)

        # remember best acc@1 and save checkpoint
        is_best = top1_acc > best_acc1
        best_acc1 = max(top1_acc, best_acc1)

        save_checkpoint(
            {
                "epoch": epoch + 1,
                "arch": "ResNet18",
                "state_dict": model.state_dict(),
                "best_acc1": best_acc1,
                "optimizer": optimizer.state_dict(),
            },
            is_best,
            filename=args.checkpoint_file + ".tar",
        )


if __name__ == "__main__":
    main()
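The `vmap(grad_with_value(...))` composition above is the interesting part of this file: it computes one gradient per example in a single vectorized call. As a reference for what it is supposed to produce, here is a sketch that computes per-sample gradients the slow, unambiguous way, with an ordinary Python loop over the batch (tiny hypothetical model and batch; any vmap-based implementation should match these values up to numerics):

import torch

torch.manual_seed(0)
model = torch.nn.Linear(4, 3)
criterion = torch.nn.CrossEntropyLoss()
images, targets = torch.randn(8, 4), torch.randint(0, 3, (8,))

per_sample_grads = []
for image, target in zip(images, targets):
    model.zero_grad()
    loss = criterion(model(image.unsqueeze(0)), target.unsqueeze(0))
    loss.backward()
    per_sample_grads.append([p.grad.detach().clone() for p in model.parameters()])

# per_sample_grads[i][j] is the gradient of sample i's loss w.r.t. parameter j;
# stacking over i gives the same [batch, *param.shape] tensors that the
# vmap-of-grad formulation returns in one shot.
stacked = [torch.stack(g) for g in zip(*per_sample_grads)]
print([t.shape for t in stacked])  # [torch.Size([8, 3, 4]), torch.Size([8, 3])]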
functorch/examples/dp_cifar10/cifar10_transforms.py.~1~ (new file, 438 lines)
@ -0,0 +1,438 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
||||
|
||||
"""
|
||||
Runs CIFAR10 training with differential privacy.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import shutil
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
import torch.utils.data
|
||||
import torch.utils.data.distributed
|
||||
import torch.utils.tensorboard as tensorboard
|
||||
import torchvision.models as models
|
||||
import torchvision.transforms as transforms
|
||||
from opacus import PrivacyEngine
|
||||
from opacus.utils import stats
|
||||
from opacus.utils.module_modification import convert_batchnorm_modules
|
||||
from torchvision.datasets import CIFAR10
|
||||
from tqdm import tqdm
|
||||
|
||||
from functools import partial
|
||||
from make_functional import make_functional, load_weights
|
||||
|
||||
# NB: The following might not exist depending on what you're using
|
||||
from torch import vmap
|
||||
from functional_utils import grad, grad_with_value
|
||||
|
||||
|
||||
def save_checkpoint(state, is_best, filename="checkpoint.tar"):
|
||||
torch.save(state, filename)
|
||||
if is_best:
|
||||
shutil.copyfile(filename, "model_best.pth.tar")
|
||||
|
||||
|
||||
def accuracy(preds, labels):
|
||||
return (preds == labels).mean()
|
||||
|
||||
|
||||
def train(args, model, train_loader, optimizer, epoch, device):
|
||||
model.train()
|
||||
criterion = nn.CrossEntropyLoss()
|
||||
|
||||
losses = []
|
||||
top1_acc = []
|
||||
|
||||
for i, (images, target) in enumerate(tqdm(train_loader)):
|
||||
images = images.to(device)
|
||||
target = target.to(device)
|
||||
|
||||
# Step 1: compute per-sample-grads
|
||||
|
||||
# TODO(rzou): does the group norm work correctly?
|
||||
weights, func_model, descriptors = make_functional(model)
|
||||
[weight.requires_grad_(False) for weight in weights]
|
||||
|
||||
def compute_loss_and_output(weights, image, target):
|
||||
images = image.unsqueeze(0)
|
||||
target = target.unsqueeze(0)
|
||||
output = func_model(weights, (images,))
|
||||
loss = criterion(output, target)
|
||||
return loss, output.squeeze(0)
|
||||
|
||||
grads_loss_output = grad_with_value(compute_loss_and_output, has_aux=True)
|
||||
grads, sample_loss, output = vmap(partial(grads_loss_output, weights))(images, target)
|
||||
loss = sample_loss.mean(0)
|
||||
|
||||
# Step 2: Clip the per-sample-grads and sum them to form grads
|
||||
|
||||
# TODO(rzou): Right now we just sum the grads. Instead we need to clip them.
|
||||
for sample_grad, weight in zip(grads, weights):
|
||||
weight_grad = sample_grad.sum(0)
|
||||
weight.grad = weight_grad
|
||||
|
||||
# `load_weights` is the inverse operation of make_functional. We put
|
||||
# things back into a model so that we can directly apply optimizers.
|
||||
# TODO(rzou): this might not be necessary, optimizers just take
|
||||
# the params straight up.
|
||||
[weight.requires_grad_(True) for weight in weights]
|
||||
load_weights(model, descriptors, weights, as_params=True)
|
||||
|
||||
preds = np.argmax(output.detach().cpu().numpy(), axis=1)
|
||||
labels = target.detach().cpu().numpy()
|
||||
|
||||
# measure accuracy and record loss
|
||||
acc1 = accuracy(preds, labels)
|
||||
|
||||
losses.append(loss.item())
|
||||
top1_acc.append(acc1)
|
||||
stats.update(stats.StatType.TRAIN, acc1=acc1)
|
||||
|
||||
# make sure we take a step after processing the last mini-batch in the
|
||||
# epoch to ensure we start the next epoch with a clean state
|
||||
if ((i + 1) % args.n_accumulation_steps == 0) or ((i + 1) == len(train_loader)):
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
else:
|
||||
optimizer.virtual_step()
|
||||
|
||||
if i % args.print_freq == 0:
|
||||
if not args.disable_dp:
|
||||
epsilon, best_alpha = optimizer.privacy_engine.get_privacy_spent(
|
||||
args.delta
|
||||
)
|
||||
print(
|
||||
f"\tTrain Epoch: {epoch} \t"
|
||||
f"Loss: {np.mean(losses):.6f} "
|
||||
f"Acc@1: {np.mean(top1_acc):.6f} "
|
||||
f"(ε = {epsilon:.2f}, δ = {args.delta}) for α = {best_alpha}"
|
||||
)
|
||||
else:
|
||||
print(
|
||||
f"\tTrain Epoch: {epoch} \t"
|
||||
f"Loss: {np.mean(losses):.6f} "
|
||||
f"Acc@1: {np.mean(top1_acc):.6f} "
|
||||
)
|
||||
|
||||
|
||||
def test(args, model, test_loader, device):
|
||||
model.eval()
|
||||
criterion = nn.CrossEntropyLoss()
|
||||
losses = []
|
||||
top1_acc = []
|
||||
|
||||
with torch.no_grad():
|
||||
for images, target in tqdm(test_loader):
|
||||
images = images.to(device)
|
||||
target = target.to(device)
|
||||
|
||||
output = model(images)
|
||||
loss = criterion(output, target)
|
||||
preds = np.argmax(output.detach().cpu().numpy(), axis=1)
|
||||
labels = target.detach().cpu().numpy()
|
||||
acc1 = accuracy(preds, labels)
|
||||
|
||||
losses.append(loss.item())
|
||||
top1_acc.append(acc1)
|
||||
|
||||
top1_avg = np.mean(top1_acc)
|
||||
stats.update(stats.StatType.TEST, acc1=top1_avg)
|
||||
|
||||
print(f"\tTest set:" f"Loss: {np.mean(losses):.6f} " f"Acc@1: {top1_avg :.6f} ")
|
||||
return np.mean(top1_acc)
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="PyTorch CIFAR10 DP Training")
|
||||
parser.add_argument(
|
||||
"-j",
|
||||
"--workers",
|
||||
default=2,
|
||||
type=int,
|
||||
metavar="N",
|
||||
help="number of data loading workers (default: 2)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--epochs",
|
||||
default=90,
|
||||
type=int,
|
||||
metavar="N",
|
||||
help="number of total epochs to run",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--start-epoch",
|
||||
default=1,
|
||||
type=int,
|
||||
metavar="N",
|
||||
help="manual epoch number (useful on restarts)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-b",
|
||||
"--batch-size",
|
||||
default=256,
|
||||
type=int,
|
||||
metavar="N",
|
||||
help="mini-batch size (default: 256), this is the total "
|
||||
"batch size of all GPUs on the current node when "
|
||||
"using Data Parallel or Distributed Data Parallel",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-na",
|
||||
"--n_accumulation_steps",
|
||||
default=1,
|
||||
type=int,
|
||||
metavar="N",
|
||||
help="number of mini-batches to accumulate into an effective batch",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--lr",
|
||||
"--learning-rate",
|
||||
default=0.001,
|
||||
type=float,
|
||||
metavar="LR",
|
||||
help="initial learning rate",
|
||||
dest="lr",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--momentum", default=0.9, type=float, metavar="M", help="SGD momentum"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--wd",
|
||||
"--weight-decay",
|
||||
default=5e-4,
|
||||
type=float,
|
||||
metavar="W",
|
||||
help="SGD weight decay (default: 1e-4)",
|
||||
dest="weight_decay",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-p",
|
||||
"--print-freq",
|
||||
default=10,
|
||||
type=int,
|
||||
metavar="N",
|
||||
help="print frequency (default: 10)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--resume",
|
||||
default="",
|
||||
type=str,
|
||||
metavar="PATH",
|
||||
help="path to latest checkpoint (default: none)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-e",
|
||||
"--evaluate",
|
||||
dest="evaluate",
|
||||
action="store_true",
|
||||
help="evaluate model on validation set",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--seed", default=None, type=int, help="seed for initializing training. "
|
||||
)
|
||||
parser.add_argument(
|
||||
"--device",
|
||||
type=str,
|
||||
default="cuda",
|
||||
help="GPU ID for this process (default: 'cuda')",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--sigma",
|
||||
type=float,
|
||||
default=1.0,
|
||||
metavar="S",
|
||||
help="Noise multiplier (default 1.0)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-c",
|
||||
"--max-per-sample-grad_norm",
|
||||
type=float,
|
||||
default=1.0,
|
||||
metavar="C",
|
||||
help="Clip per-sample gradients to this norm (default 1.0)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--disable-dp",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Disable privacy training and just train with vanilla SGD",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--secure-rng",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Enable Secure RNG to have trustworthy privacy guarantees. Comes at a performance cost",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--delta",
|
||||
type=float,
|
||||
default=1e-5,
|
||||
metavar="D",
|
||||
help="Target delta (default: 1e-5)",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--checkpoint-file",
|
||||
type=str,
|
||||
default="checkpoint",
|
||||
help="path to save check points",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--data-root",
|
||||
type=str,
|
||||
default="../cifar10",
|
||||
help="Where CIFAR10 is/will be stored",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--log-dir", type=str, default="", help="Where Tensorboard log will be stored"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--optim",
|
||||
type=str,
|
||||
default="Adam",
|
||||
help="Optimizer to use (Adam, RMSprop, SGD)",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
args.disable_dp = True
|
||||
|
||||
if args.disable_dp and args.n_accumulation_steps > 1:
|
||||
raise ValueError("Virtual steps only works with enabled DP")
|
||||
|
||||
# The following few lines, enable stats gathering about the run
|
||||
# 1. where the stats should be logged
|
||||
stats.set_global_summary_writer(
|
||||
tensorboard.SummaryWriter(os.path.join("/tmp/stat", args.log_dir))
|
||||
)
|
||||
# 2. enable stats
|
||||
stats.add(
|
||||
# stats about gradient norms aggregated for all layers
|
||||
stats.Stat(stats.StatType.GRAD, "AllLayers", frequency=0.1),
|
||||
# stats about gradient norms per layer
|
||||
stats.Stat(stats.StatType.GRAD, "PerLayer", frequency=0.1),
|
||||
# stats about clipping
|
||||
        stats.Stat(stats.StatType.GRAD, "ClippingStats", frequency=0.1),
        # stats on training accuracy
        stats.Stat(stats.StatType.TRAIN, "accuracy", frequency=0.01),
        # stats on validation accuracy
        stats.Stat(stats.StatType.TEST, "accuracy"),
    )

    # The following lines enable stat gathering for the clipping process
    # and configure flat (rather than per-layer) clipping for the Privacy Engine.
    clipping = {"clip_per_layer": False, "enable_stat": True}

    if args.secure_rng:
        # NB: the secure RNG path is intentionally disabled in this example.
        assert False
        try:
            import torchcsprng as prng
        except ImportError as e:
            msg = (
                "To use secure RNG, you must install the torchcsprng package! "
                "Check out the instructions here: https://github.com/pytorch/csprng#installation"
            )
            raise ImportError(msg) from e

        generator = prng.create_random_device_generator("/dev/urandom")
    else:
        generator = None

    augmentations = [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
    ]
    normalize = [
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ]
    train_transform = transforms.Compose(
        augmentations + normalize if args.disable_dp else normalize
    )

    test_transform = transforms.Compose(normalize)

    train_dataset = CIFAR10(
        root=args.data_root, train=True, download=True, transform=train_transform
    )

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.workers,
        drop_last=True,
        generator=generator,
    )

    test_dataset = CIFAR10(
        root=args.data_root, train=False, download=True, transform=test_transform
    )
    test_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.workers,
    )

    best_acc1 = 0
    device = torch.device(args.device)
    model = convert_batchnorm_modules(models.resnet18(num_classes=10))
    model = model.to(device)

    if args.optim == "SGD":
        optimizer = optim.SGD(
            model.parameters(),
            lr=args.lr,
            momentum=args.momentum,
            weight_decay=args.weight_decay,
        )
    elif args.optim == "RMSprop":
        optimizer = optim.RMSprop(model.parameters(), lr=args.lr)
    elif args.optim == "Adam":
        optimizer = optim.Adam(model.parameters(), lr=args.lr)
    else:
        raise NotImplementedError("Optimizer not recognized. Please check spelling.")

    if not args.disable_dp:
        privacy_engine = PrivacyEngine(
            model,
            batch_size=args.batch_size * args.n_accumulation_steps,
            sample_size=len(train_dataset),
            alphas=[1 + x / 10.0 for x in range(1, 100)] + list(range(12, 64)),
            noise_multiplier=args.sigma,
            max_grad_norm=args.max_per_sample_grad_norm,
            secure_rng=args.secure_rng,
            **clipping,
        )
        privacy_engine.attach(optimizer)

    for epoch in range(args.start_epoch, args.epochs + 1):
        train(args, model, train_loader, optimizer, epoch, device)
        top1_acc = test(args, model, test_loader, device)

        # remember best acc@1 and save checkpoint
        is_best = top1_acc > best_acc1
        best_acc1 = max(top1_acc, best_acc1)

        save_checkpoint(
            {
                "epoch": epoch + 1,
                "arch": "ResNet18",
                "state_dict": model.state_dict(),
                "best_acc1": best_acc1,
                "optimizer": optimizer.state_dict(),
            },
            is_best,
            filename=args.checkpoint_file + ".tar",
        )


if __name__ == "__main__":
    main()
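
# Example invocation (the flags shown are illustrative; see the argparse
# definitions in this file for the full list):
#
#   python cifar10_expandweights.py --epochs 1 --batch-size 64 --data-root ../cifar10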
491
functorch/examples/dp_cifar10/cifar10_transforms.py.~2~
Normal file
@ -0,0 +1,491 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

"""
Runs CIFAR10 training with differential privacy.
"""

import argparse
import os
import shutil

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data
import torch.utils.data.distributed
import torch.utils.tensorboard as tensorboard
import torchvision.models as models
import torchvision.transforms as transforms
from opacus import PrivacyEngine
from opacus.utils import stats
from opacus.utils.module_modification import convert_batchnorm_modules
from torchvision.datasets import CIFAR10
from tqdm import tqdm

from torch import vmap
from make_functional import make_functional, load_weights
from functional_utils import grad, grad_with_value
from functools import partial
# from resnet import resnet18


class CudaMemoryLeakCheck():
    def __init__(self, name):
        self.name = name

        # initialize context & RNG to prevent false positive detections
        # when the test is the first to initialize those
        from torch.testing._internal.common_cuda import initialize_cuda_context_rng
        initialize_cuda_context_rng()

    @staticmethod
    def get_cuda_memory_usage():
        # we don't need CUDA synchronize because the statistics are not tracked
        # at actual freeing, but when marking the block as free.
        num_devices = torch.cuda.device_count()
        import gc
        gc.collect()
        return tuple(torch.cuda.memory_allocated(i) for i in range(num_devices))

    def __enter__(self):
        self.befores = self.get_cuda_memory_usage()

    def __exit__(self, exec_type, exec_value, traceback):
        # Don't check for leaks if an exception was thrown
        if exec_type is not None:
            return

        afters = self.get_cuda_memory_usage()

        for i, (before, after) in enumerate(zip(self.befores, afters)):
            if after - before == 0:
                continue
            raise RuntimeError(f'{self.name} leaked {after - before} bytes on device {i}')


def save_checkpoint(state, is_best, filename="checkpoint.tar"):
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, "model_best.pth.tar")


def accuracy(preds, labels):
    return (preds == labels).mean()


def compute_norms(sample_grads):
    batch_size = sample_grads[0].shape[0]
    norms = [sample_grad.view(batch_size, -1).norm(2, dim=-1) for sample_grad in sample_grads]
    norms = torch.stack(norms, dim=0).norm(2, dim=0)
    return norms
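
# For example, with two parameters whose per-sample grads have shapes
# (B, 3, 3) and (B, 5), `compute_norms` returns a tensor of shape (B,) whose
# b-th entry is sqrt(norm(g0[b])**2 + norm(g1[b])**2), i.e. a flat per-sample
# "global" gradient norm.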


def clip_and_accumulate_and_add_noise(sample_grads, max_per_sample_grad_norm=1.0, noise_multiplier=1.0):
    # step 0: compute the per-sample gradient norms
    sample_norms = compute_norms(sample_grads)

    # step 1: compute clipping factors
    clip_factor = max_per_sample_grad_norm / (sample_norms + 1e-6)
    clip_factor = clip_factor.clamp(max=1.0)

    # step 2: clip (and sum over the batch dimension)
    grads = tuple(torch.einsum('i,i...', clip_factor, sample_grad)
                  for sample_grad in sample_grads)

    # step 3: add gaussian noise
    stddev = max_per_sample_grad_norm * noise_multiplier
    noises = tuple(torch.normal(0, stddev, grad_param.shape, device=grad_param.device)
                   for grad_param in grads)
    grads = tuple(noise + grad_param for noise, grad_param in zip(noises, grads))

    return grads
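
# A minimal illustration of the helper above (shapes are hypothetical):
#
#   sample_grads = (torch.randn(4, 3, 3), torch.randn(4, 5))  # batch of 4
#   grads = clip_and_accumulate_and_add_noise(sample_grads)
#   # grads[0].shape == (3, 3) and grads[1].shape == (5,): each per-sample
#   # gradient is rescaled to norm <= max_per_sample_grad_norm, summed over
#   # the batch, and perturbed with Gaussian noise of stddev
#   # max_per_sample_grad_norm * noise_multiplier.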


def train(args, model, train_loader, optimizer, epoch, device):
    model.train()
    criterion = nn.CrossEntropyLoss()

    losses = []
    top1_acc = []

    for i, (images, target) in enumerate(tqdm(train_loader)):
        images = images.to(device)
        target = target.to(device)

        # Step 1: compute per-sample-grads
        weights, func_model, descriptors = make_functional(model)

        def compute_loss_and_output(weights, image, target):
            images = image.unsqueeze(0)
            targets = target.unsqueeze(0)
            output = func_model(weights, (images,))
            loss = criterion(output, targets)
            return loss, output.squeeze(0)

        # grad_with_value(f) returns a function that computes (1) the gradient
        # and (2) the value of f. `has_aux=True` means that `f` returns a pair
        # (loss, aux): only the first element is differentiated, and the
        # auxiliary value is passed through unchanged as an extra output.
        #
        # We need `grad_with_value(..., has_aux=True)` because we do some
        # analyses on both the returned loss and the output.
        grads_loss_output = grad_with_value(compute_loss_and_output, has_aux=True)
        sample_grads, sample_loss, output = vmap(partial(grads_loss_output, weights))(images, target)
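        # After vmap, `sample_grads` contains one gradient per example: each
        # entry has a leading batch dimension prepended to the corresponding
        # weight's shape, and `sample_loss` has shape (batch_size,).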
        loss = sample_loss.mean()

        # Step 2: Clip the per-sample-grads, sum them to form grads, and add noise
        grads = clip_and_accumulate_and_add_noise(
            sample_grads, args.max_per_sample_grad_norm, args.sigma)

        # `load_weights` is the inverse operation of `make_functional`. We put
        # things back into a model so that we can directly apply optimizers.
        # TODO(rzou): this might not be necessary, optimizers just take
        # the params straight up.
        load_weights(model, descriptors, weights)

        for weight_grad, weight in zip(grads, model.parameters()):
            weight.grad = weight_grad.detach()

        preds = np.argmax(output.detach().cpu().numpy(), axis=1)
        labels = target.detach().cpu().numpy()
        losses.append(loss.item())

        # measure accuracy and record loss
        acc1 = accuracy(preds, labels)

        top1_acc.append(acc1)
        stats.update(stats.StatType.TRAIN, acc1=acc1)

        # make sure we take a step after processing the last mini-batch in the
        # epoch to ensure we start the next epoch with a clean state
        if ((i + 1) % args.n_accumulation_steps == 0) or ((i + 1) == len(train_loader)):
            optimizer.step()
            optimizer.zero_grad()
        else:
            optimizer.virtual_step()

        if i % args.print_freq == 0:
            print(
                f"\tTrain Epoch: {epoch} \t"
                f"Loss: {np.mean(losses):.6f} "
                f"Acc@1: {np.mean(top1_acc):.6f} "
            )


def test(args, model, test_loader, device):
    model.eval()
    criterion = nn.CrossEntropyLoss()
    losses = []
    top1_acc = []

    with torch.no_grad():
        for images, target in tqdm(test_loader):
            images = images.to(device)
            target = target.to(device)

            output = model(images)
            loss = criterion(output, target)
            preds = np.argmax(output.detach().cpu().numpy(), axis=1)
            labels = target.detach().cpu().numpy()
            acc1 = accuracy(preds, labels)

            losses.append(loss.item())
            top1_acc.append(acc1)

    top1_avg = np.mean(top1_acc)
    stats.update(stats.StatType.TEST, acc1=top1_avg)

    print(f"\tTest set: Loss: {np.mean(losses):.6f} Acc@1: {top1_avg:.6f}")
    return np.mean(top1_acc)


def main():
    parser = argparse.ArgumentParser(description="PyTorch CIFAR10 DP Training")
    parser.add_argument(
        "-j",
        "--workers",
        default=2,
        type=int,
        metavar="N",
        help="number of data loading workers (default: 2)",
    )
    parser.add_argument(
        "--epochs",
        default=90,
        type=int,
        metavar="N",
        help="number of total epochs to run",
    )
    parser.add_argument(
        "--start-epoch",
        default=1,
        type=int,
        metavar="N",
        help="manual epoch number (useful on restarts)",
    )
    parser.add_argument(
        "-b",
        "--batch-size",
        # This should be 256, but that OOMs using the prototype.
        default=64,
        type=int,
        metavar="N",
        help="mini-batch size (default: 64); this is the total "
        "batch size across all GPUs on the current node when "
        "using Data Parallel or Distributed Data Parallel",
    )
    parser.add_argument(
        "-na",
        "--n_accumulation_steps",
        default=1,
        type=int,
        metavar="N",
        help="number of mini-batches to accumulate into an effective batch",
    )
    parser.add_argument(
        "--lr",
        "--learning-rate",
        default=0.001,
        type=float,
        metavar="LR",
        help="initial learning rate",
        dest="lr",
    )
    parser.add_argument(
        "--momentum", default=0.9, type=float, metavar="M", help="SGD momentum"
    )
    parser.add_argument(
        "--wd",
        "--weight-decay",
        default=5e-4,
        type=float,
        metavar="W",
        help="SGD weight decay (default: 5e-4)",
        dest="weight_decay",
    )
    parser.add_argument(
        "-p",
        "--print-freq",
        default=10,
        type=int,
        metavar="N",
        help="print frequency (default: 10)",
    )
    parser.add_argument(
        "--resume",
        default="",
        type=str,
        metavar="PATH",
        help="path to latest checkpoint (default: none)",
    )
    parser.add_argument(
        "-e",
        "--evaluate",
        dest="evaluate",
        action="store_true",
        help="evaluate model on validation set",
    )
    parser.add_argument(
        "--seed", default=None, type=int, help="seed for initializing training"
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cuda",
        help="device on which to run the training (default: 'cuda')",
    )
    parser.add_argument(
        "--sigma",
        type=float,
        default=1.0,
        metavar="S",
        help="Noise multiplier (default 1.0)",
    )
    parser.add_argument(
        "-c",
        "--max-per-sample-grad-norm",
        type=float,
        default=1.0,
        metavar="C",
        help="Clip per-sample gradients to this norm (default 1.0)",
    )
    parser.add_argument(
        "--disable-dp",
        action="store_true",
        default=False,
        help="Disable privacy training and just train with vanilla SGD",
    )
    parser.add_argument(
        "--secure-rng",
        action="store_true",
        default=False,
        help="Enable Secure RNG to have trustworthy privacy guarantees. Comes at a performance cost",
    )
    parser.add_argument(
        "--delta",
        type=float,
        default=1e-5,
        metavar="D",
        help="Target delta (default: 1e-5)",
    )

    parser.add_argument(
        "--checkpoint-file",
        type=str,
        default="checkpoint",
        help="path to save checkpoints",
    )
    parser.add_argument(
        "--data-root",
        type=str,
        default="../cifar10",
        help="Where CIFAR10 is/will be stored",
    )
    parser.add_argument(
        "--log-dir", type=str, default="", help="Where Tensorboard log will be stored"
    )
    parser.add_argument(
        "--optim",
        type=str,
        default="Adam",
        help="Optimizer to use (Adam, RMSprop, SGD)",
    )

    args = parser.parse_args()
    # NB: the Opacus PrivacyEngine path is force-disabled in this prototype;
    # clipping and noise addition happen manually in train() instead.
    args.disable_dp = True

    if args.disable_dp and args.n_accumulation_steps > 1:
        raise ValueError("Virtual steps only work when DP is enabled")

    # The following few lines enable stats gathering about the run
    # 1. where the stats should be logged
    stats.set_global_summary_writer(
        tensorboard.SummaryWriter(os.path.join("/tmp/stat", args.log_dir))
    )
    # 2. enable stats
    stats.add(
        # stats about gradient norms aggregated for all layers
        stats.Stat(stats.StatType.GRAD, "AllLayers", frequency=0.1),
        # stats about gradient norms per layer
        stats.Stat(stats.StatType.GRAD, "PerLayer", frequency=0.1),
        # stats about clipping
        stats.Stat(stats.StatType.GRAD, "ClippingStats", frequency=0.1),
        # stats on training accuracy
        stats.Stat(stats.StatType.TRAIN, "accuracy", frequency=0.01),
        # stats on validation accuracy
        stats.Stat(stats.StatType.TEST, "accuracy"),
    )

    # The following lines enable stat gathering for the clipping process
    # and configure flat (rather than per-layer) clipping for the Privacy Engine.
    clipping = {"clip_per_layer": False, "enable_stat": True}

    if args.secure_rng:
        # NB: the secure RNG path is intentionally disabled in this example.
        assert False
        try:
            import torchcsprng as prng
        except ImportError as e:
            msg = (
                "To use secure RNG, you must install the torchcsprng package! "
                "Check out the instructions here: https://github.com/pytorch/csprng#installation"
            )
            raise ImportError(msg) from e

        generator = prng.create_random_device_generator("/dev/urandom")
    else:
        generator = None

    augmentations = [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
    ]
    normalize = [
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ]
    train_transform = transforms.Compose(
        augmentations + normalize if args.disable_dp else normalize
    )

    test_transform = transforms.Compose(normalize)

    train_dataset = CIFAR10(
        root=args.data_root, train=True, download=True, transform=train_transform
    )

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.workers,
        drop_last=True,
        generator=generator,
    )

    test_dataset = CIFAR10(
        root=args.data_root, train=False, download=True, transform=test_transform
    )
    test_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.workers,
    )

    best_acc1 = 0
    device = torch.device(args.device)
    model = convert_batchnorm_modules(models.resnet18(num_classes=10))
    # model = CIFAR10Model()
    model = model.to(device)

    if args.optim == "SGD":
        optimizer = optim.SGD(
            model.parameters(),
            lr=args.lr,
            momentum=args.momentum,
            weight_decay=args.weight_decay,
        )
    elif args.optim == "RMSprop":
        optimizer = optim.RMSprop(model.parameters(), lr=args.lr)
    elif args.optim == "Adam":
        optimizer = optim.Adam(model.parameters(), lr=args.lr)
    else:
        raise NotImplementedError("Optimizer not recognized. Please check spelling.")

    if not args.disable_dp:
        privacy_engine = PrivacyEngine(
            model,
            batch_size=args.batch_size * args.n_accumulation_steps,
            sample_size=len(train_dataset),
            alphas=[1 + x / 10.0 for x in range(1, 100)] + list(range(12, 64)),
            noise_multiplier=args.sigma,
            max_grad_norm=args.max_per_sample_grad_norm,
            secure_rng=args.secure_rng,
            **clipping,
        )
        privacy_engine.attach(optimizer)

    for epoch in range(args.start_epoch, args.epochs + 1):
        train(args, model, train_loader, optimizer, epoch, device)
        top1_acc = test(args, model, test_loader, device)

        # remember best acc@1 and save checkpoint
        is_best = top1_acc > best_acc1
        best_acc1 = max(top1_acc, best_acc1)

        save_checkpoint(
            {
                "epoch": epoch + 1,
                "arch": "ResNet18",
                "state_dict": model.state_dict(),
                "best_acc1": best_acc1,
                "optimizer": optimizer.state_dict(),
            },
            is_best,
            filename=args.checkpoint_file + ".tar",
        )


if __name__ == "__main__":
    main()
71
functorch/examples/dp_cifar10/make_functional.py
Normal file
@ -0,0 +1,71 @@
import torch
import torch.nn as nn
from torch import Tensor
from typing import List, Tuple
import copy


# Utilities to make an nn.Module "functional".
# In particular, the goal is to be able to provide a function that takes the
# parameters as input and evaluates the nn.Module on some inputs.
def _del_nested_attr(obj: nn.Module, names: List[str]) -> None:
    """
    Deletes the attribute specified by the given list of names.
    For example, to delete the attribute obj.conv.weight,
    use _del_nested_attr(obj, ['conv', 'weight'])
    """
    if len(names) == 1:
        delattr(obj, names[0])
    else:
        _del_nested_attr(getattr(obj, names[0]), names[1:])


def _set_nested_attr(obj: nn.Module, names: List[str], value: Tensor) -> None:
    """
    Sets the attribute specified by the given list of names to value.
    For example, to set the attribute obj.conv.weight,
    use _set_nested_attr(obj, ['conv', 'weight'], value)
    """
    if len(names) == 1:
        setattr(obj, names[0], value)
    else:
        _set_nested_attr(getattr(obj, names[0]), names[1:], value)


def extract_weights(mod: nn.Module) -> Tuple[Tuple[Tensor, ...], List[str]]:
    """
    This function removes all the Parameters from the model and
    returns them as a tuple, along with their original attribute names.
    The weights must be re-loaded with `load_weights` before the model
    can be used again.
    Note that this function modifies the model in place and after this
    call, mod.parameters() will be empty.
    """
    orig_params = tuple(mod.parameters())
    # Remove all the parameters in the model
    names = []
    for name, p in list(mod.named_parameters()):
        _del_nested_attr(mod, name.split("."))
        names.append(name)

    # NB: the params are returned as the original nn.Parameters (not detached
    # copies), so that `load_weights` re-registers them on the module and any
    # existing optimizer references remain valid.
    params = tuple(p for p in orig_params)
    return params, names


def load_weights(mod: nn.Module, names: List[str], params: Tuple[Tensor, ...], as_params=False) -> None:
    """
    Reload a set of weights so that `mod` can be used again to perform a forward pass.
    Note that plain-Tensor `params` (which can have history) are left as Tensors,
    in which case mod.parameters() will still be empty after this call; pass
    as_params=True to re-register them as nn.Parameters.
    """
    for name, p in zip(names, params):
        if as_params:
            p = nn.Parameter(p)
        _set_nested_attr(mod, name.split("."), p)


def make_functional(model: nn.Module):
    weights, descriptors = extract_weights(model)

    def fun(weights, data):
        mutable_model = copy.deepcopy(model)
        load_weights(mutable_model, descriptors, weights)
        return mutable_model(*data)

    return weights, fun, descriptors
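

# A minimal usage sketch (hypothetical module, for illustration only):
#
#   model = nn.Linear(3, 2)
#   weights, func_model, descriptors = make_functional(model)
#   out = func_model(weights, (torch.randn(4, 3),))
#   load_weights(model, descriptors, weights)  # restore the module's parameters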
122
functorch/examples/ensembling/parallel_train.py
Normal file
@ -0,0 +1,122 @@
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from functorch import make_functional, grad_with_value, vmap

# Adapted from http://willwhitney.com/parallel-training-jax.html
# GOAL: Demonstrate that it is possible to use eager-mode vmap
# to parallelize training over models.

# NB: this code runs off of a branch on zou3519/pytorch:dynlayer

DEVICE = 'cpu'

# Step 1: Make some spirals
def make_spirals(n_samples, noise_std=0., rotations=1.):
    ts = torch.linspace(0, 1, n_samples, device=DEVICE)
    rs = ts ** 0.5
    thetas = rs * rotations * 2 * math.pi
    signs = torch.randint(0, 2, (n_samples,), device=DEVICE) * 2 - 1
    labels = (signs > 0).to(torch.long)

    xs = rs * signs * torch.cos(thetas) + torch.randn(n_samples, device=DEVICE) * noise_std
    ys = rs * signs * torch.sin(thetas) + torch.randn(n_samples, device=DEVICE) * noise_std
    points = torch.stack([xs, ys], dim=1)
    return points, labels


points, labels = make_spirals(100, noise_std=0.05)


# Step 2: Define a two-layer MLP and loss function
class MLPClassifier(nn.Module):
    def __init__(self, hidden_dim=32, n_classes=2):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.n_classes = n_classes

        self.fc1 = nn.Linear(2, self.hidden_dim)
        self.fc2 = nn.Linear(self.hidden_dim, self.n_classes)

    def forward(self, x):
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        x = F.log_softmax(x, -1)
        return x


loss_fn = nn.NLLLoss()

# Step 3: Make the model functional(!!) and define a training function.
# NB: this mechanism doesn't exist in PyTorch today, but we want it to:
# https://github.com/pytorch/pytorch/issues/49171
weights, func_model, _ = make_functional(MLPClassifier().to(DEVICE))

def train_step_fn(weights, batch, targets, lr=0.2):
    def compute_loss(weights, batch, targets):
        output = func_model(weights, (batch,))
        loss = loss_fn(output, targets)
        return loss

    grad_weights, loss = grad_with_value(compute_loss)(weights, batch, targets)

    # NB: PyTorch is missing a "functional optimizer API" (possibly coming soon)
    # so we are going to re-implement SGD here.
    new_weights = []
    with torch.no_grad():
        for grad_weight, weight in zip(grad_weights, weights):
            new_weights.append(weight - grad_weight * lr)
    # NB: the return looks weird because torch.vmap must return Tensors
    return (loss, *new_weights)


def unpack(train_result):
    return train_result[0], train_result[1:]


# Step 4: Let's verify this actually trains.
# We should see the loss decrease.
def step4():
    global weights
    for i in range(2000):
        loss, weights = unpack(train_step_fn(weights, points, labels))
        if i % 100 == 0:
            print(loss)


step4()


# Step 5: We're ready for multiple models. Let's define an init_fn
# that, given a number of models, returns all of their weights.
def init_fn(num_models):
    models = tuple(MLPClassifier() for _ in range(num_models))
    weights = tuple(make_functional(model)[0] for model in models)
    weights = tuple(zip(*weights))
    weights = tuple(torch.stack(shards).detach() for shards in weights)
    return weights
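
# After `init_fn`, `weights` is a tuple with one entry per parameter; each
# entry stacks that parameter across models, with shape (num_models, *param_shape).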

# Step 6: Now, can we try multiple models at the same time?
# The answer is: yes! `loss` now holds one entry per model, and we can see
# that the values keep on decreasing.
def step6():
    parallel_train_step_fn = vmap(train_step_fn, in_dims=(0, None, None))
    batched_weights = init_fn(num_models=2)
    for i in range(2000):
        loss, batched_weights = unpack(parallel_train_step_fn(batched_weights, points, labels))
        if i % 200 == 0:
            print(loss)


step6()


# Step 7: Now, the flaw with step 6 is that we were training on exactly the
# same data. This can lead to all of the models in the ensemble overfitting in
# the same way. The solution that http://willwhitney.com/parallel-training-jax.html
# applies is to randomly subset the data in a way that the models do not receive
# exactly the same data in each training step!
# Because the goal of this doc is to show that we can use eager-mode vmap to
# achieve similar things as JAX, the rest of this is left as an exercise to the
# reader; a brief sketch follows below.
|
||||
|
||||
# In conclusion, to achieve what http://willwhitney.com/parallel-training-jax.html
|
||||
# does, we used the following additional items that PyTorch does not have:
|
||||
# 1. NN module functional API that turns a module into a (state, state_less_fn) pair
|
||||
# 2. Functional optimizers
|
||||
# 3. A "functional" grad API (that effectively wraps autograd.grad)
|
||||
# 4. Composability between the functional grad API and torch.vmap.
|
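# A minimal sketch of one way to do Step 7 (this sketch is not part of the
# original walkthrough, and the batch size of 20 is an arbitrary assumption):
# draw an independent minibatch of indices for each model and vmap over the
# data as well as the weights.
def step7_sketch(num_models=2, batch_size=20, num_steps=2000):
    parallel_train_step_fn = vmap(train_step_fn, in_dims=(0, 0, 0))
    batched_weights = init_fn(num_models)
    for i in range(num_steps):
        # One independently sampled subset of the data per model.
        idx = torch.randint(0, points.shape[0], (num_models, batch_size), device=DEVICE)
        batched_points = points[idx]   # [num_models, batch_size, 2]
        batched_labels = labels[idx]   # [num_models, batch_size]
        loss, batched_weights = unpack(
            parallel_train_step_fn(batched_weights, batched_points, batched_labels))
        if i % 200 == 0:
            print(loss)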
2
functorch/examples/maml_omniglot/.gdbinit
Normal file
@ -0,0 +1,2 @@
catch throw
r maml-omniglot-transforms.py
3
functorch/examples/maml_omniglot/.gitignore
vendored
Normal file
@ -0,0 +1,3 @@
omniglot/
maml-accs.png

17
functorch/examples/maml_omniglot/README.md
Normal file
@ -0,0 +1,17 @@
# Omniglot MAML examples

In this directory we've provided some examples of training Omniglot that reproduce the experiments from [the original MAML paper](https://arxiv.org/abs/1703.03400).

They can be run via `python {filename}`.

`maml-omniglot-higher.py` uses the [facebookresearch/higher](https://github.com/facebookresearch/higher) metalearning package and is the reference implementation. It runs all of its tasks sequentially.

`maml-omniglot-transforms.py` uses an experimental vmap (and functional grad) prototype. It runs all of its tasks in parallel. In theory this should lead to some speedups, but we haven't finished implementing all the rules for vmap that would actually make training faster.

`maml-omniglot-ptonly.py` is an implementation of `maml-omniglot-transforms.py` that runs all of its tasks sequentially (and also doesn't use the higher package).

The prototype vmap used for these experiments currently runs off of a branch.
We'd love some feedback on the prototype and encourage folks to try it out.
It's a bit difficult to install, but here are some options:
1. If you're on the FAIR cluster, we can share a path to a conda environment.
2. We are looking into building binaries using our branch and shipping them.
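To give a flavor of what the transforms-based version does, here is a minimal sketch of the pattern it uses: take `grad` of the per-task inner loss, then `vmap` the whole per-task computation over the task dimension. This sketch assumes the prototype `functorch` package from the branch is installed; `fnet`, `params`, and `buffers` come from `make_functional_with_buffers`, and the support/query tensors come from the Omniglot loader.

```python
import functools
import torch.nn.functional as F
from functorch import grad, vmap

def loss_for_task(fnet, params, buffers, n_inner_iter, x_spt, y_spt, x_qry, y_qry):
    def inner_loss(new_params, x, y):
        return F.cross_entropy(fnet(new_params, buffers, (x,)), y)

    # A few differentiable SGD steps on the support set.
    new_params = params
    for _ in range(n_inner_iter):
        grads = grad(inner_loss)(new_params, x_spt, y_spt)
        new_params = [p - 1e-1 * g for p, g in zip(new_params, grads)]

    # The adapted parameters induce a loss on the query set.
    return F.cross_entropy(fnet(new_params, buffers, (x_qry,)), y_qry)

# One meta-step: vmap runs the inner loop for every task in parallel.
per_task = functools.partial(loss_for_task, fnet, params, buffers, 5)
qry_losses = vmap(per_task)(x_spt, y_spt, x_qry, y_qry)
qry_losses.sum().backward()
```

Compare `maml-omniglot-higher.py`, which performs the same inner loop one task at a time.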
277
functorch/examples/maml_omniglot/maml-omniglot-higher.py
Executable file
@ -0,0 +1,277 @@
#!/usr/bin/env python3
#
# Copyright (c) Facebook, Inc. and its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
This example shows how to use higher to do Model Agnostic Meta Learning (MAML)
for few-shot Omniglot classification.
For more details see the original MAML paper:
https://arxiv.org/abs/1703.03400

This code has been modified from Jackie Loong's PyTorch MAML implementation:
https://github.com/dragen1860/MAML-Pytorch/blob/master/omniglot_train.py

Our MAML++ fork and experiments are available at:
https://github.com/bamos/HowToTrainYourMAMLPytorch
"""

import argparse
import time
import typing

import pandas as pd
import numpy as np
import matplotlib as mpl
mpl.use('Agg')
import matplotlib.pyplot as plt
plt.style.use('bmh')

import torch
from torch import nn
import torch.nn.functional as F
import torch.optim as optim

import higher

from support.omniglot_loaders import OmniglotNShot


def main():
    argparser = argparse.ArgumentParser()
    argparser.add_argument('--n_way', type=int, help='n way', default=5)
    argparser.add_argument(
        '--k_spt', type=int, help='k shot for support set', default=5)
    argparser.add_argument(
        '--k_qry', type=int, help='k shot for query set', default=15)
    argparser.add_argument(
        '--device', type=str, help='device', default='cuda')
    argparser.add_argument(
        '--task_num',
        type=int,
        help='meta batch size, namely task num',
        default=32)
    argparser.add_argument('--seed', type=int, help='random seed', default=1)
    args = argparser.parse_args()

    torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)
    np.random.seed(args.seed)

    # Set up the Omniglot loader.
    device = args.device
    db = OmniglotNShot(
        '/tmp/omniglot-data',
        batchsz=args.task_num,
        n_way=args.n_way,
        k_shot=args.k_spt,
        k_query=args.k_qry,
        imgsz=28,
        device=device,
    )

    # Create a vanilla PyTorch neural network that will be
    # automatically monkey-patched by higher later.
    # Before higher, models could *not* be created like this
    # and the parameters needed to be manually updated and copied
    # for the updates.
    net = nn.Sequential(
        nn.Conv2d(1, 64, 3),
        nn.BatchNorm2d(64, momentum=1, affine=True),
        nn.ReLU(inplace=True),
        nn.MaxPool2d(2, 2),
        nn.Conv2d(64, 64, 3),
        nn.BatchNorm2d(64, momentum=1, affine=True),
        nn.ReLU(inplace=True),
        nn.MaxPool2d(2, 2),
        nn.Conv2d(64, 64, 3),
        nn.BatchNorm2d(64, momentum=1, affine=True),
        nn.ReLU(inplace=True),
        nn.MaxPool2d(2, 2),
        Flatten(),
        nn.Linear(64, args.n_way)).to(device)

    # We will use Adam to (meta-)optimize the initial parameters
    # to be adapted.
    meta_opt = optim.Adam(net.parameters(), lr=1e-3)

    log = []
    for epoch in range(100):
        train(db, net, device, meta_opt, epoch, log)
        test(db, net, device, epoch, log)
        plot(log)


def train(db, net, device, meta_opt, epoch, log):
    net.train()
    n_train_iter = db.x_train.shape[0] // db.batchsz

    for batch_idx in range(n_train_iter):
        start_time = time.time()
        # Sample a batch of support and query images and labels.
        x_spt, y_spt, x_qry, y_qry = db.next()

        task_num, setsz, c_, h, w = x_spt.size()
        querysz = x_qry.size(1)

        # TODO: Maybe pull this out into a separate module so it
        # doesn't have to be duplicated between `train` and `test`?

        # Initialize the inner optimizer to adapt the parameters to
        # the support set.
        n_inner_iter = 5
        inner_opt = torch.optim.SGD(net.parameters(), lr=1e-1)

        qry_losses = []
        qry_accs = []
        meta_opt.zero_grad()
        for i in range(task_num):
            with higher.innerloop_ctx(
                net, inner_opt, copy_initial_weights=False
            ) as (fnet, diffopt):
                # Optimize the likelihood of the support set by taking
                # gradient steps w.r.t. the model's parameters.
                # This adapts the model's meta-parameters to the task.
                # higher is able to automatically keep copies of
                # your network's parameters as they are being updated.
                for _ in range(n_inner_iter):
                    spt_logits = fnet(x_spt[i])
                    spt_loss = F.cross_entropy(spt_logits, y_spt[i])
                    diffopt.step(spt_loss)

                # The final set of adapted parameters will induce some
                # final loss and accuracy on the query dataset.
                # These will be used to update the model's meta-parameters.
                qry_logits = fnet(x_qry[i])
                qry_loss = F.cross_entropy(qry_logits, y_qry[i])
                qry_losses.append(qry_loss.detach())
                qry_acc = (qry_logits.argmax(
                    dim=1) == y_qry[i]).sum().item() / querysz
                qry_accs.append(qry_acc)

                # print([b.shape for b in fnet[1].buffers()])

                # Update the model's meta-parameters to optimize the query
                # losses across all of the tasks sampled in this batch.
                # This unrolls through the gradient steps.
                qry_loss.backward()

        meta_opt.step()
        qry_losses = sum(qry_losses) / task_num
        qry_accs = 100. * sum(qry_accs) / task_num
        i = epoch + float(batch_idx) / n_train_iter
        iter_time = time.time() - start_time
        if batch_idx % 4 == 0:
            print(
                f'[Epoch {i:.2f}] Train Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f} | Time: {iter_time:.2f}'
            )

        log.append({
            'epoch': i,
            'loss': qry_losses,
            'acc': qry_accs,
            'mode': 'train',
            'time': time.time(),
        })


def test(db, net, device, epoch, log):
    # Crucially in our testing procedure here, we do *not* fine-tune
    # the model during testing for simplicity.
    # Most research papers using MAML for this task do an extra
    # stage of fine-tuning here that should be added if you are
    # adapting this code for research.
    net.train()
    n_test_iter = db.x_test.shape[0] // db.batchsz

    qry_losses = []
    qry_accs = []

    for batch_idx in range(n_test_iter):
        x_spt, y_spt, x_qry, y_qry = db.next('test')

        task_num, setsz, c_, h, w = x_spt.size()
        querysz = x_qry.size(1)

        # TODO: Maybe pull this out into a separate module so it
        # doesn't have to be duplicated between `train` and `test`?
        n_inner_iter = 5
        inner_opt = torch.optim.SGD(net.parameters(), lr=1e-1)

        for i in range(task_num):
            with higher.innerloop_ctx(net, inner_opt, track_higher_grads=False) as (fnet, diffopt):
                # Optimize the likelihood of the support set by taking
                # gradient steps w.r.t. the model's parameters.
                # This adapts the model's meta-parameters to the task.
                for _ in range(n_inner_iter):
                    spt_logits = fnet(x_spt[i])
                    spt_loss = F.cross_entropy(spt_logits, y_spt[i])
                    diffopt.step(spt_loss)

                # The query loss and acc induced by these parameters.
                qry_logits = fnet(x_qry[i]).detach()
                qry_loss = F.cross_entropy(
                    qry_logits, y_qry[i], reduction='none')
                qry_losses.append(qry_loss.detach())
                qry_accs.append(
                    (qry_logits.argmax(dim=1) == y_qry[i]).detach())

    qry_losses = torch.cat(qry_losses).mean().item()
    qry_accs = 100. * torch.cat(qry_accs).float().mean().item()
    print(
        f'[Epoch {epoch+1:.2f}] Test Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f}'
    )
    log.append({
        'epoch': epoch + 1,
        'loss': qry_losses,
        'acc': qry_accs,
        'mode': 'test',
        'time': time.time(),
    })


def plot(log):
    # Generally you should pull your plotting code out of your training
    # script but we are doing it here for brevity.
    df = pd.DataFrame(log)

    fig, ax = plt.subplots(figsize=(6, 4))
    train_df = df[df['mode'] == 'train']
    test_df = df[df['mode'] == 'test']
    ax.plot(train_df['epoch'], train_df['acc'], label='Train')
    ax.plot(test_df['epoch'], test_df['acc'], label='Test')
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Accuracy')
    ax.set_ylim(70, 100)
    fig.legend(ncol=2, loc='lower right')
    fig.tight_layout()
    fname = 'maml-accs.png'
    print(f'--- Plotting accuracy to {fname}')
    fig.savefig(fname)
    plt.close(fig)


# Won't need this after this PR is merged in:
# https://github.com/pytorch/pytorch/pull/22245
class Flatten(nn.Module):
    def forward(self, input):
        return input.view(input.size(0), -1)


if __name__ == '__main__':
    main()
270
functorch/examples/maml_omniglot/maml-omniglot-ptonly.py
Executable file
@ -0,0 +1,270 @@
#!/usr/bin/env python3
#
# Copyright (c) Facebook, Inc. and its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
This example shows how to do Model Agnostic Meta Learning (MAML) for few-shot
Omniglot classification without the higher package, by updating a functional
version of the model directly.
For more details see the original MAML paper:
https://arxiv.org/abs/1703.03400

This code has been modified from Jackie Loong's PyTorch MAML implementation:
https://github.com/dragen1860/MAML-Pytorch/blob/master/omniglot_train.py

Our MAML++ fork and experiments are available at:
https://github.com/bamos/HowToTrainYourMAMLPytorch
"""

import argparse
import time
import typing

import pandas as pd
import numpy as np
import matplotlib as mpl
mpl.use('Agg')
import matplotlib.pyplot as plt
plt.style.use('bmh')

import torch
from torch import nn
import torch.nn.functional as F
import torch.optim as optim
from torch.eager_transforms import make_functional_with_buffers

from support.omniglot_loaders import OmniglotNShot


def main():
    argparser = argparse.ArgumentParser()
    argparser.add_argument('--n_way', type=int, help='n way', default=5)
    argparser.add_argument(
        '--k_spt', type=int, help='k shot for support set', default=5)
    argparser.add_argument(
        '--k_qry', type=int, help='k shot for query set', default=15)
    argparser.add_argument(
        '--device', type=str, help='device', default='cuda')
    argparser.add_argument(
        '--task_num',
        type=int,
        help='meta batch size, namely task num',
        default=32)
    argparser.add_argument('--seed', type=int, help='random seed', default=1)
    args = argparser.parse_args()

    torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)
    np.random.seed(args.seed)

    # Set up the Omniglot loader.
    device = args.device
    db = OmniglotNShot(
        '/tmp/omniglot-data',
        batchsz=args.task_num,
        n_way=args.n_way,
        k_shot=args.k_spt,
        k_query=args.k_qry,
        imgsz=28,
        device=device,
    )

    # Create a vanilla PyTorch neural network; we will rip out its
    # parameters and buffers below and manipulate them functionally.
    net = nn.Sequential(
        nn.Conv2d(1, 64, 3),
        nn.BatchNorm2d(64, momentum=1, affine=True),
        nn.ReLU(inplace=True),
        nn.MaxPool2d(2, 2),
        nn.Conv2d(64, 64, 3),
        nn.BatchNorm2d(64, momentum=1, affine=True),
        nn.ReLU(inplace=True),
        nn.MaxPool2d(2, 2),
        nn.Conv2d(64, 64, 3),
        nn.BatchNorm2d(64, momentum=1, affine=True),
        nn.ReLU(inplace=True),
        nn.MaxPool2d(2, 2),
        Flatten(),
        nn.Linear(64, args.n_way)).to(device)

    net.train()
    params, buffers, fnet, _, _ = make_functional_with_buffers(net)

    # We will use Adam to (meta-)optimize the initial parameters
    # to be adapted.
    meta_opt = optim.Adam(params, lr=1e-3)

    log = []
    for epoch in range(100):
        train(db, [params, buffers, fnet], device, meta_opt, epoch, log)
        test(db, [params, buffers, fnet], device, epoch, log)
        plot(log)


def train(db, net, device, meta_opt, epoch, log):
    params, buffers, fnet = net
    n_train_iter = db.x_train.shape[0] // db.batchsz

    for batch_idx in range(n_train_iter):
        start_time = time.time()
        # Sample a batch of support and query images and labels.
        x_spt, y_spt, x_qry, y_qry = db.next()

        task_num, setsz, c_, h, w = x_spt.size()
        querysz = x_qry.size(1)

        # TODO: Maybe pull this out into a separate module so it
        # doesn't have to be duplicated between `train` and `test`?

        # Initialize the inner optimizer to adapt the parameters to
        # the support set.
        n_inner_iter = 5
        # inner_opt = torch.optim.SGD(net.parameters(), lr=1e-1)

        qry_losses = []
        qry_accs = []
        meta_opt.zero_grad()
        for i in range(task_num):
            # Optimize the likelihood of the support set by taking
            # gradient steps w.r.t. the model's parameters.
            # This adapts the model's meta-parameters to the task.
            new_params = params
            for _ in range(n_inner_iter):
                spt_logits = fnet(new_params, buffers, (x_spt[i],))
                spt_loss = F.cross_entropy(spt_logits, y_spt[i])
                grads = torch.autograd.grad(spt_loss, new_params, create_graph=True)
                new_params = [p - g * 1e-1 for p, g in zip(new_params, grads)]

            # The final set of adapted parameters will induce some
            # final loss and accuracy on the query dataset.
            # These will be used to update the model's meta-parameters.
            qry_logits = fnet(new_params, buffers, (x_qry[i],))
            qry_loss = F.cross_entropy(qry_logits, y_qry[i])
            qry_losses.append(qry_loss.detach())
            qry_acc = (qry_logits.argmax(
                dim=1) == y_qry[i]).sum().item() / querysz
            qry_accs.append(qry_acc)

            # Update the model's meta-parameters to optimize the query
            # losses across all of the tasks sampled in this batch.
            # This unrolls through the gradient steps.
            qry_loss.backward()

        meta_opt.step()
        qry_losses = sum(qry_losses) / task_num
        qry_accs = 100. * sum(qry_accs) / task_num
        i = epoch + float(batch_idx) / n_train_iter
        iter_time = time.time() - start_time
        if batch_idx % 4 == 0:
            print(
                f'[Epoch {i:.2f}] Train Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f} | Time: {iter_time:.2f}'
            )

        log.append({
            'epoch': i,
            'loss': qry_losses,
            'acc': qry_accs,
            'mode': 'train',
            'time': time.time(),
        })


def test(db, net, device, epoch, log):
    # Crucially in our testing procedure here, we do *not* fine-tune
    # the model during testing for simplicity.
    # Most research papers using MAML for this task do an extra
    # stage of fine-tuning here that should be added if you are
    # adapting this code for research.
    [params, buffers, fnet] = net
    n_test_iter = db.x_test.shape[0] // db.batchsz

    qry_losses = []
    qry_accs = []

    for batch_idx in range(n_test_iter):
        x_spt, y_spt, x_qry, y_qry = db.next('test')
        task_num, setsz, c_, h, w = x_spt.size()

        # TODO: Maybe pull this out into a separate module so it
        # doesn't have to be duplicated between `train` and `test`?
        n_inner_iter = 5

        for i in range(task_num):
            new_params = params
            for _ in range(n_inner_iter):
                spt_logits = fnet(new_params, buffers, (x_spt[i],))
                spt_loss = F.cross_entropy(spt_logits, y_spt[i])
                grads = torch.autograd.grad(spt_loss, new_params)
                new_params = [p - g * 1e-1 for p, g in zip(new_params, grads)]

            # The query loss and acc induced by these parameters.
            qry_logits = fnet(new_params, buffers, (x_qry[i],)).detach()
            qry_loss = F.cross_entropy(
                qry_logits, y_qry[i], reduction='none')
            qry_losses.append(qry_loss.detach())
            qry_accs.append(
                (qry_logits.argmax(dim=1) == y_qry[i]).detach())

    qry_losses = torch.cat(qry_losses).mean().item()
    qry_accs = 100. * torch.cat(qry_accs).float().mean().item()
    print(
        f'[Epoch {epoch+1:.2f}] Test Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f}'
    )
    log.append({
        'epoch': epoch + 1,
        'loss': qry_losses,
        'acc': qry_accs,
        'mode': 'test',
        'time': time.time(),
    })


def plot(log):
    # Generally you should pull your plotting code out of your training
    # script but we are doing it here for brevity.
    df = pd.DataFrame(log)

    fig, ax = plt.subplots(figsize=(6, 4))
    train_df = df[df['mode'] == 'train']
    test_df = df[df['mode'] == 'test']
    ax.plot(train_df['epoch'], train_df['acc'], label='Train')
    ax.plot(test_df['epoch'], test_df['acc'], label='Test')
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Accuracy')
    ax.set_ylim(70, 100)
    fig.legend(ncol=2, loc='lower right')
    fig.tight_layout()
    fname = 'maml-accs.png'
    print(f'--- Plotting accuracy to {fname}')
    fig.savefig(fname)
    plt.close(fig)


# Won't need this after this PR is merged in:
# https://github.com/pytorch/pytorch/pull/22245
class Flatten(nn.Module):
    def forward(self, input):
        return input.view(input.size(0), -1)


if __name__ == '__main__':
    main()
276
functorch/examples/maml_omniglot/maml-omniglot-transforms.py
Executable file
@ -0,0 +1,276 @@
#!/usr/bin/env python3
#
# Copyright (c) Facebook, Inc. and its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
This example shows how to use the eager-mode transforms prototype (vmap and
the functional grad API) to do Model Agnostic Meta Learning (MAML) for
few-shot Omniglot classification.
For more details see the original MAML paper:
https://arxiv.org/abs/1703.03400

This code has been modified from Jackie Loong's PyTorch MAML implementation:
https://github.com/dragen1860/MAML-Pytorch/blob/master/omniglot_train.py

Our MAML++ fork and experiments are available at:
https://github.com/bamos/HowToTrainYourMAMLPytorch
"""

import argparse
import time
import typing
import functools

import pandas as pd
import numpy as np
import matplotlib as mpl
mpl.use('Agg')
import matplotlib.pyplot as plt
plt.style.use('bmh')

import torch
from torch import nn
import torch.nn.functional as F
import torch.optim as optim

from functorch import make_functional_with_buffers, vmap, grad

from support.omniglot_loaders import OmniglotNShot
# torch._C._debug_only_display_vmap_fallback_warnings(True)


def main():
    argparser = argparse.ArgumentParser()
    argparser.add_argument('--n_way', type=int, help='n way', default=5)
    argparser.add_argument(
        '--k_spt', type=int, help='k shot for support set', default=5)
    argparser.add_argument(
        '--k_qry', type=int, help='k shot for query set', default=15)
    argparser.add_argument(
        '--device', type=str, help='device', default='cuda')
    argparser.add_argument(
        '--task_num',
        type=int,
        help='meta batch size, namely task num',
        default=32)
    argparser.add_argument('--seed', type=int, help='random seed', default=1)
    args = argparser.parse_args()

    torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)
    np.random.seed(args.seed)

    # Set up the Omniglot loader.
    device = args.device
    db = OmniglotNShot(
        '/tmp/omniglot-data',
        batchsz=args.task_num,
        n_way=args.n_way,
        k_shot=args.k_spt,
        k_query=args.k_qry,
        imgsz=28,
        device=device,
    )

    # Create a vanilla PyTorch neural network.
    # TODO: The prototype doesn't support in-place relu (and some other
    # in-place operations). That can be fixed.
    inplace_relu = False
    net = nn.Sequential(
        nn.Conv2d(1, 64, 3),
        nn.BatchNorm2d(64, momentum=1, affine=True),
        nn.ReLU(inplace=inplace_relu),
        nn.MaxPool2d(2, 2),
        nn.Conv2d(64, 64, 3),
        nn.BatchNorm2d(64, momentum=1, affine=True),
        nn.ReLU(inplace=inplace_relu),
        nn.MaxPool2d(2, 2),
        nn.Conv2d(64, 64, 3),
        nn.BatchNorm2d(64, momentum=1, affine=True),
        nn.ReLU(inplace=inplace_relu),
        nn.MaxPool2d(2, 2),
        Flatten(),
        nn.Linear(64, args.n_way)).to(device)

    net.train()

    # Given this module we've created, rip out the parameters and buffers
    # and return a functional version of the module. `fnet` is stateless
    # and can be called with `fnet(params, buffers, args, kwargs)`
    params, buffers, fnet, _, _ = make_functional_with_buffers(net)

    # We will use Adam to (meta-)optimize the initial parameters
    # to be adapted.
    meta_opt = optim.Adam(params, lr=1e-3)

    log = []
    for epoch in range(100):
        train(db, [params, buffers, fnet], device, meta_opt, epoch, log)
        test(db, [params, buffers, fnet], device, epoch, log)
        plot(log)


# Trains a model for n_inner_iter using the support and returns a loss
# using the query.
def loss_for_task(net, n_inner_iter, x_spt, y_spt, x_qry, y_qry):
    params, buffers, fnet = net
    querysz = x_qry.size(0)

    def compute_loss(new_params, buffers, x, y):
        logits = fnet(new_params, buffers, (x,))
        loss = F.cross_entropy(logits, y)
        return loss

    new_params = params
    for _ in range(n_inner_iter):
        grads = grad(compute_loss)(new_params, buffers, x_spt, y_spt)
        new_params = [p - g * 1e-1 for p, g in zip(new_params, grads)]

    # The final set of adapted parameters will induce some
    # final loss and accuracy on the query dataset.
    # These will be used to update the model's meta-parameters.
    qry_logits = fnet(new_params, buffers, (x_qry,))
    qry_loss = F.cross_entropy(qry_logits, y_qry)
    qry_acc = (qry_logits.argmax(
        dim=1) == y_qry).sum() / querysz

    return qry_loss, qry_acc


def train(db, net, device, meta_opt, epoch, log):
    params, buffers, fnet = net
    n_train_iter = db.x_train.shape[0] // db.batchsz

    for batch_idx in range(n_train_iter):
        start_time = time.time()
        # Sample a batch of support and query images and labels.
        x_spt, y_spt, x_qry, y_qry = db.next()

        task_num, setsz, c_, h, w = x_spt.size()

        n_inner_iter = 5
        meta_opt.zero_grad()

        # In parallel, trains one model per task. There is a support (x, y)
        # for each task and a query (x, y) for each task.
        compute_loss_for_task = functools.partial(loss_for_task, net, n_inner_iter)
        qry_losses, qry_accs = vmap(compute_loss_for_task)(x_spt, y_spt, x_qry, y_qry)

        # Compute the maml loss by summing together the returned losses.
        qry_losses.sum().backward()

        meta_opt.step()
        qry_losses = qry_losses.detach().sum() / task_num
        qry_accs = 100. * qry_accs.sum() / task_num
        i = epoch + float(batch_idx) / n_train_iter
        iter_time = time.time() - start_time
        if batch_idx % 4 == 0:
            print(
                f'[Epoch {i:.2f}] Train Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f} | Time: {iter_time:.2f}'
            )

        log.append({
            'epoch': i,
            'loss': qry_losses,
            'acc': qry_accs,
            'mode': 'train',
            'time': time.time(),
        })


def test(db, net, device, epoch, log):
    # Crucially in our testing procedure here, we do *not* fine-tune
    # the model during testing for simplicity.
    # Most research papers using MAML for this task do an extra
    # stage of fine-tuning here that should be added if you are
    # adapting this code for research.
    [params, buffers, fnet] = net
    n_test_iter = db.x_test.shape[0] // db.batchsz

    qry_losses = []
    qry_accs = []

    for batch_idx in range(n_test_iter):
        x_spt, y_spt, x_qry, y_qry = db.next('test')
        task_num, setsz, c_, h, w = x_spt.size()

        # TODO: Maybe pull this out into a separate module so it
        # doesn't have to be duplicated between `train` and `test`?
        n_inner_iter = 5

        for i in range(task_num):
            new_params = params
            for _ in range(n_inner_iter):
                spt_logits = fnet(new_params, buffers, (x_spt[i],))
                spt_loss = F.cross_entropy(spt_logits, y_spt[i])
                grads = torch.autograd.grad(spt_loss, new_params)
                new_params = [p - g * 1e-1 for p, g in zip(new_params, grads)]

            # The query loss and acc induced by these parameters.
            qry_logits = fnet(new_params, buffers, (x_qry[i],)).detach()
            qry_loss = F.cross_entropy(
                qry_logits, y_qry[i], reduction='none')
            qry_losses.append(qry_loss.detach())
            qry_accs.append(
                (qry_logits.argmax(dim=1) == y_qry[i]).detach())

    qry_losses = torch.cat(qry_losses).mean().item()
    qry_accs = 100. * torch.cat(qry_accs).float().mean().item()
    print(
        f'[Epoch {epoch+1:.2f}] Test Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f}'
    )
    log.append({
        'epoch': epoch + 1,
        'loss': qry_losses,
        'acc': qry_accs,
        'mode': 'test',
        'time': time.time(),
    })


def plot(log):
    # Generally you should pull your plotting code out of your training
    # script but we are doing it here for brevity.
    df = pd.DataFrame(log)

    fig, ax = plt.subplots(figsize=(6, 4))
    train_df = df[df['mode'] == 'train']
    test_df = df[df['mode'] == 'test']
    ax.plot(train_df['epoch'], train_df['acc'], label='Train')
    ax.plot(test_df['epoch'], test_df['acc'], label='Test')
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Accuracy')
    ax.set_ylim(70, 100)
    fig.legend(ncol=2, loc='lower right')
    fig.tight_layout()
    fname = 'maml-accs.png'
    print(f'--- Plotting accuracy to {fname}')
    fig.savefig(fname)
    plt.close(fig)


# Won't need this after this PR is merged in:
# https://github.com/pytorch/pytorch/pull/22245
class Flatten(nn.Module):
    def forward(self, input):
        return input.view(input.size(0), -1)


if __name__ == '__main__':
    main()
303
functorch/examples/maml_omniglot/support/omniglot_loaders.py
Normal file
@ -0,0 +1,303 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# These Omniglot loaders are from Jackie Loong's PyTorch MAML implementation:
# https://github.com/dragen1860/MAML-Pytorch
# https://github.com/dragen1860/MAML-Pytorch/blob/master/omniglot.py
# https://github.com/dragen1860/MAML-Pytorch/blob/master/omniglotNShot.py

import torchvision.transforms as transforms
from PIL import Image
import numpy as np

import torch
import torch.utils.data as data
import os
import os.path
import errno


class Omniglot(data.Dataset):
    urls = [
        'https://github.com/brendenlake/omniglot/raw/master/python/images_background.zip',
        'https://github.com/brendenlake/omniglot/raw/master/python/images_evaluation.zip'
    ]
    raw_folder = 'raw'
    processed_folder = 'processed'
    training_file = 'training.pt'
    test_file = 'test.pt'

    '''
    The items are (filename, category). The index of all the categories can be found in self.idx_classes
    Args:
    - root: the directory where the dataset will be stored
    - transform: how to transform the input
    - target_transform: how to transform the target
    - download: whether to download the dataset if it is missing
    '''

    def __init__(self, root, transform=None, target_transform=None,
                 download=False):
        self.root = root
        self.transform = transform
        self.target_transform = target_transform

        if not self._check_exists():
            if download:
                self.download()
            else:
                raise RuntimeError('Dataset not found. You can use download=True to download it')

        self.all_items = find_classes(os.path.join(self.root, self.processed_folder))
        self.idx_classes = index_classes(self.all_items)

    def __getitem__(self, index):
        filename = self.all_items[index][0]
        img = str.join('/', [self.all_items[index][2], filename])

        target = self.idx_classes[self.all_items[index][1]]
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self):
        return len(self.all_items)

    def _check_exists(self):
        return os.path.exists(os.path.join(self.root, self.processed_folder, "images_evaluation")) and \
            os.path.exists(os.path.join(self.root, self.processed_folder, "images_background"))

    def download(self):
        from six.moves import urllib
        import zipfile

        if self._check_exists():
            return

        # download files
        try:
            os.makedirs(os.path.join(self.root, self.raw_folder))
            os.makedirs(os.path.join(self.root, self.processed_folder))
        except OSError as e:
            if e.errno == errno.EEXIST:
                pass
            else:
                raise

        for url in self.urls:
            print('== Downloading ' + url)
            data = urllib.request.urlopen(url)
            filename = url.rpartition('/')[2]
            file_path = os.path.join(self.root, self.raw_folder, filename)
            with open(file_path, 'wb') as f:
                f.write(data.read())
            file_processed = os.path.join(self.root, self.processed_folder)
            print("== Unzip from " + file_path + " to " + file_processed)
            zip_ref = zipfile.ZipFile(file_path, 'r')
            zip_ref.extractall(file_processed)
            zip_ref.close()
        print("Download finished.")


def find_classes(root_dir):
    retour = []
    for (root, dirs, files) in os.walk(root_dir):
        for f in files:
            if (f.endswith("png")):
                r = root.split('/')
                lr = len(r)
                retour.append((f, r[lr - 2] + "/" + r[lr - 1], root))
    print("== Found %d items " % len(retour))
    return retour


def index_classes(items):
    idx = {}
    for i in items:
        if i[1] not in idx:
            idx[i[1]] = len(idx)
    print("== Found %d classes" % len(idx))
    return idx


class OmniglotNShot:

    def __init__(self, root, batchsz, n_way, k_shot, k_query, imgsz, device=None):
        """
        Different from mnistNShot, the
        :param root: data root directory
        :param batchsz: task num
        :param n_way: n way
        :param k_shot: k shot for the support set
        :param k_query: k shot for the query set
        :param imgsz: image size after resizing
        :param device: device to place the cached tensors on
        """

        self.resize = imgsz
        self.device = device
        if not os.path.isfile(os.path.join(root, 'omniglot.npy')):
            # if root/omniglot.npy does not exist, just download it
            self.x = Omniglot(
                root, download=True,
                transform=transforms.Compose(
                    [lambda x: Image.open(x).convert('L'),
                     lambda x: x.resize((imgsz, imgsz)),
                     lambda x: np.reshape(x, (imgsz, imgsz, 1)),
                     lambda x: np.transpose(x, [2, 0, 1]),
                     lambda x: x/255.]),
            )

            temp = dict()  # {label: [img1, img2, ...]}; 20 imgs per label, 1623 labels in total
            for (img, label) in self.x:
                if label in temp.keys():
                    temp[label].append(img)
                else:
                    temp[label] = [img]

            self.x = []
            for label, imgs in temp.items():  # label info is discarded; each label contains 20 imgs
                self.x.append(np.array(imgs))

            # as different classes may have different numbers of imgs
            self.x = np.array(self.x).astype(np.float64)  # [[20 imgs], ..., 1623 classes in total]
            # each character contains 20 imgs
            print('data shape:', self.x.shape)  # [1623, 20, 84, 84, 1]
            temp = []  # Free memory
            # save the whole dataset into an npy file.
            np.save(os.path.join(root, 'omniglot.npy'), self.x)
            print('write into omniglot.npy.')
        else:
            # if omniglot.npy exists, just load it.
            self.x = np.load(os.path.join(root, 'omniglot.npy'))
            print('load from omniglot.npy.')

        # [1623, 20, 84, 84, 1]
        # TODO: can not shuffle here, we must keep training and test set distinct!
        self.x_train, self.x_test = self.x[:1200], self.x[1200:]

        # self.normalization()

        self.batchsz = batchsz
        self.n_cls = self.x.shape[0]  # 1623
        self.n_way = n_way  # n way
        self.k_shot = k_shot  # k shot
        self.k_query = k_query  # k query
        assert (k_shot + k_query) <= 20

        # save pointer of current read batch in total cache
        self.indexes = {"train": 0, "test": 0}
        self.datasets = {"train": self.x_train, "test": self.x_test}  # original data cached
        print("DB: train", self.x_train.shape, "test", self.x_test.shape)

        self.datasets_cache = {"train": self.load_data_cache(self.datasets["train"]),  # current epoch data cached
                               "test": self.load_data_cache(self.datasets["test"])}

    def normalization(self):
        """
        Normalizes our data to have a mean of 0 and std of 1.
        """
        self.mean = np.mean(self.x_train)
        self.std = np.std(self.x_train)
        self.max = np.max(self.x_train)
        self.min = np.min(self.x_train)
        # print("before norm:", "mean", self.mean, "max", self.max, "min", self.min, "std", self.std)
        self.x_train = (self.x_train - self.mean) / self.std
        self.x_test = (self.x_test - self.mean) / self.std

        self.mean = np.mean(self.x_train)
        self.std = np.std(self.x_train)
        self.max = np.max(self.x_train)
        self.min = np.min(self.x_train)

        # print("after norm:", "mean", self.mean, "max", self.max, "min", self.min, "std", self.std)

    def load_data_cache(self, data_pack):
        """
        Collects several batches of data for N-shot learning.
        :param data_pack: [cls_num, 20, 84, 84, 1]
        :return: A list with [support_set_x, support_set_y, target_x, target_y] ready to be fed to our networks
        """
        # take 5 way 1 shot as an example: 5 * 1
        setsz = self.k_shot * self.n_way
        querysz = self.k_query * self.n_way
        data_cache = []

        # print('preload next 50 caches of batchsz of batch.')
        for sample in range(10):  # num of episodes

            x_spts, y_spts, x_qrys, y_qrys = [], [], [], []
            for i in range(self.batchsz):  # one batch means one set

                x_spt, y_spt, x_qry, y_qry = [], [], [], []
                selected_cls = np.random.choice(data_pack.shape[0], self.n_way, False)

                for j, cur_class in enumerate(selected_cls):

                    selected_img = np.random.choice(20, self.k_shot + self.k_query, False)

                    # meta-training and meta-test
                    x_spt.append(data_pack[cur_class][selected_img[:self.k_shot]])
                    x_qry.append(data_pack[cur_class][selected_img[self.k_shot:]])
                    y_spt.append([j for _ in range(self.k_shot)])
                    y_qry.append([j for _ in range(self.k_query)])

                # shuffle inside a batch
                perm = np.random.permutation(self.n_way * self.k_shot)
                x_spt = np.array(x_spt).reshape(self.n_way * self.k_shot, 1, self.resize, self.resize)[perm]
                y_spt = np.array(y_spt).reshape(self.n_way * self.k_shot)[perm]
                perm = np.random.permutation(self.n_way * self.k_query)
                x_qry = np.array(x_qry).reshape(self.n_way * self.k_query, 1, self.resize, self.resize)[perm]
                y_qry = np.array(y_qry).reshape(self.n_way * self.k_query)[perm]

                # append [sptsz, 1, 84, 84] => [b, setsz, 1, 84, 84]
                x_spts.append(x_spt)
                y_spts.append(y_spt)
                x_qrys.append(x_qry)
                y_qrys.append(y_qry)

            # [b, setsz, 1, 84, 84]
            x_spts = np.array(x_spts).astype(np.float32).reshape(self.batchsz, setsz, 1, self.resize, self.resize)
            y_spts = np.array(y_spts).astype(np.int64).reshape(self.batchsz, setsz)
            # [b, qrysz, 1, 84, 84]
            x_qrys = np.array(x_qrys).astype(np.float32).reshape(self.batchsz, querysz, 1, self.resize, self.resize)
            y_qrys = np.array(y_qrys).astype(np.int64).reshape(self.batchsz, querysz)

            x_spts, y_spts, x_qrys, y_qrys = [
                torch.from_numpy(z).to(self.device) for z in
                [x_spts, y_spts, x_qrys, y_qrys]
            ]

            data_cache.append([x_spts, y_spts, x_qrys, y_qrys])

        return data_cache

    def next(self, mode='train'):
        """
        Gets the next batch from the dataset with the given name.
        :param mode: The name of the split (one of "train", "val", "test")
        :return:
        """
        # update the cache if the index is larger than the number of cached batches
        if self.indexes[mode] >= len(self.datasets_cache[mode]):
            self.indexes[mode] = 0
            self.datasets_cache[mode] = self.load_data_cache(self.datasets[mode])

        next_batch = self.datasets_cache[mode][self.indexes[mode]]
        self.indexes[mode] += 1

        return next_batch
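# A small usage sketch (the path and argument values here are illustrative
# assumptions) showing how the MAML examples consume this loader:
if __name__ == '__main__':
    db = OmniglotNShot('/tmp/omniglot-data', batchsz=32, n_way=5,
                       k_shot=5, k_query=15, imgsz=28, device='cpu')
    x_spt, y_spt, x_qry, y_qry = db.next()
    # x_spt: [32, 25, 1, 28, 28] = (batchsz, n_way * k_shot, 1, imgsz, imgsz)
    # x_qry: [32, 75, 1, 28, 28] = (batchsz, n_way * k_query, 1, imgsz, imgsz)
    print(x_spt.shape, y_spt.shape, x_qry.shape, y_qry.shape)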
113
functorch/examples/maml_regression/evjang.py
Normal file
@ -0,0 +1,113 @@
import math
import random
import torch
import numpy as np
from torch import nn
from torch.nn import functional as F
import matplotlib as mpl
mpl.use('Agg')
import matplotlib.pyplot as plt

def net(x, params):
    x = F.linear(x, params[0], params[1])
    x = F.relu(x)

    x = F.linear(x, params[2], params[3])
    x = F.relu(x)

    x = F.linear(x, params[4], params[5])
    return x

params = [
    torch.Tensor(40, 1).uniform_(-1., 1.).requires_grad_(),
    torch.Tensor(40).zero_().requires_grad_(),

    torch.Tensor(40, 40).uniform_(-1./math.sqrt(40), 1./math.sqrt(40)).requires_grad_(),
    torch.Tensor(40).zero_().requires_grad_(),

    torch.Tensor(1, 40).uniform_(-1./math.sqrt(40), 1./math.sqrt(40)).requires_grad_(),
    torch.Tensor(1).zero_().requires_grad_(),
]

opt = torch.optim.Adam(params, lr=1e-3)
alpha = 0.1

K = 20
losses = []
num_tasks = 4
def sample_tasks(outer_batch_size, inner_batch_size):
    # Select an amplitude and phase for each task
    As = []
    phases = []
    for _ in range(outer_batch_size):
        As.append(np.random.uniform(low=0.1, high=.5))
        phases.append(np.random.uniform(low=0., high=np.pi))
    def get_batch():
        xs, ys = [], []
        for A, phase in zip(As, phases):
            x = np.random.uniform(low=-5., high=5., size=(inner_batch_size, 1))
            y = A * np.sin(x + phase)
            xs.append(x)
            ys.append(y)
        return torch.tensor(xs, dtype=torch.float), torch.tensor(ys, dtype=torch.float)
    x1, y1 = get_batch()
    x2, y2 = get_batch()
    return x1, y1, x2, y2

for it in range(20000):
    loss2 = 0.0
    opt.zero_grad()
    def get_loss_for_task(x1, y1, x2, y2):
        f = net(x1, params)
        loss = F.mse_loss(f, y1)

        # create_graph=True because computing grads here is part of the forward pass.
        # We want to differentiate through the SGD update steps and get higher order
        # derivatives in the backward pass.
        grads = torch.autograd.grad(loss, params, create_graph=True)
        new_params = [(params[i] - alpha*grads[i]) for i in range(len(params))]

        v_f = net(x2, new_params)
        return F.mse_loss(v_f, y2)

    task = sample_tasks(num_tasks, K)
    inner_losses = [get_loss_for_task(task[0][i], task[1][i], task[2][i], task[3][i]) for i in range(num_tasks)]
    loss2 = sum(inner_losses)/len(inner_losses)
    loss2.backward()

    opt.step()

    if it % 100 == 0:
        print('Iteration %d -- Outer Loss: %.4f' % (it, loss2))
    losses.append(loss2.item())

t_A = torch.tensor(0.0).uniform_(0.1, 0.5)
t_b = torch.tensor(0.0).uniform_(0.0, math.pi)

t_x = torch.empty(4, 1).uniform_(-5, 5)
t_y = t_A*torch.sin(t_x + t_b)

opt.zero_grad()

t_params = params
for k in range(5):
    t_f = net(t_x, t_params)
    t_loss = F.l1_loss(t_f, t_y)

    grads = torch.autograd.grad(t_loss, t_params, create_graph=True)
    t_params = [(t_params[i] - alpha*grads[i]) for i in range(len(params))]


test_x = torch.arange(-2*math.pi, 2*math.pi, step=0.01).unsqueeze(1)
test_y = t_A*torch.sin(test_x + t_b)

test_f = net(test_x, t_params)

plt.plot(test_x.data.numpy(), test_y.data.numpy(), label='sin(x)')
plt.plot(test_x.data.numpy(), test_f.data.numpy(), label='net(x)')
plt.plot(t_x.data.numpy(), t_y.data.numpy(), 'o', label='Examples')
plt.legend()
plt.savefig('maml-sine.png')
plt.figure()
plt.plot(np.convolve(losses, [.05]*20))
plt.savefig('losses.png')
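# For comparison with evjang_transforms.py: the per-task Python loop above,
#     inner_losses = [get_loss_for_task(task[0][i], ...) for i in range(num_tasks)]
# is exactly the piece that vmap replaces. A sketch of the replacement,
# assuming a get_loss_for_task written with the functional grad API as in the
# transforms version:
#
#     from functorch import vmap
#     inner_losses = vmap(get_loss_for_task)(task[0], task[1], task[2], task[3])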
118
functorch/examples/maml_regression/evjang_transforms.py
Normal file
@ -0,0 +1,118 @@
import math
import random
import torch
import numpy as np
from torch import nn
from torch.nn import functional as F
import matplotlib as mpl
mpl.use('Agg')
import matplotlib.pyplot as plt

from functorch import grad, vmap

def net(params, x):
    x = F.linear(x, params[0], params[1])
    x = F.relu(x)

    x = F.linear(x, params[2], params[3])
    x = F.relu(x)

    x = F.linear(x, params[4], params[5])
    return x

params = [
    torch.Tensor(40, 1).uniform_(-1., 1.).requires_grad_(),
    torch.Tensor(40).zero_().requires_grad_(),

    torch.Tensor(40, 40).uniform_(-1./math.sqrt(40), 1./math.sqrt(40)).requires_grad_(),
    torch.Tensor(40).zero_().requires_grad_(),

    torch.Tensor(1, 40).uniform_(-1./math.sqrt(40), 1./math.sqrt(40)).requires_grad_(),
    torch.Tensor(1).zero_().requires_grad_(),
]

# The prototype doesn't like F.mse_loss.
def mse_loss(x, y):
    return torch.mean((x - y) ** 2)

opt = torch.optim.Adam(params, lr=1e-3)
alpha = 0.1

K = 20
losses = []
num_tasks = 4
def sample_tasks(outer_batch_size, inner_batch_size):
    # Select an amplitude and phase for each task
    As = []
    phases = []
    for _ in range(outer_batch_size):
        As.append(np.random.uniform(low=0.1, high=.5))
        phases.append(np.random.uniform(low=0., high=np.pi))
    def get_batch():
        xs, ys = [], []
        for A, phase in zip(As, phases):
            x = np.random.uniform(low=-5., high=5., size=(inner_batch_size, 1))
            y = A * np.sin(x + phase)
            xs.append(x)
            ys.append(y)
        return torch.tensor(xs, dtype=torch.float), torch.tensor(ys, dtype=torch.float)
    x1, y1 = get_batch()
    x2, y2 = get_batch()
    return x1, y1, x2, y2

for it in range(20000):
    loss2 = 0.0
    opt.zero_grad()
    def get_loss_for_task(x1, y1, x2, y2):
        def inner_loss(params, x1, y1):
            f = net(params, x1)
            loss = mse_loss(f, y1)
            return loss

        grads = grad(inner_loss)(tuple(params), x1, y1)
        new_params = [(params[i] - alpha*grads[i]) for i in range(len(params))]

        v_f = net(new_params, x2)
        return mse_loss(v_f, y2)

    task = sample_tasks(num_tasks, K)
    inner_losses = vmap(get_loss_for_task)(task[0], task[1], task[2], task[3])
    loss2 = sum(inner_losses)/len(inner_losses)
    loss2.backward()

    opt.step()

    if it % 100 == 0:
        print('Iteration %d -- Outer Loss: %.4f' % (it, loss2))
    losses.append(loss2.item())

t_A = torch.tensor(0.0).uniform_(0.1, 0.5)
t_b = torch.tensor(0.0).uniform_(0.0, math.pi)

t_x = torch.empty(4, 1).uniform_(-5, 5)
t_y = t_A*torch.sin(t_x + t_b)

opt.zero_grad()

t_params = params
for k in range(5):
    # NB: `net` here takes params first, so the call order differs from evjang.py.
    t_f = net(t_params, t_x)
    t_loss = F.l1_loss(t_f, t_y)

    grads = torch.autograd.grad(t_loss, t_params, create_graph=True)
    t_params = [(t_params[i] - alpha*grads[i]) for i in range(len(params))]


test_x = torch.arange(-2*math.pi, 2*math.pi, step=0.01).unsqueeze(1)
test_y = t_A*torch.sin(test_x + t_b)

test_f = net(t_params, test_x)

plt.plot(test_x.data.numpy(), test_y.data.numpy(), label='sin(x)')
plt.plot(test_x.data.numpy(), test_f.data.numpy(), label='net(x)')
plt.plot(t_x.data.numpy(), t_y.data.numpy(), 'o', label='Examples')
plt.legend()
plt.savefig('maml-sine.png')
plt.figure()
plt.plot(np.convolve(losses, [.05]*20))
plt.savefig('losses.png')
115
functorch/examples/maml_regression/evjang_transforms_module.py
Normal file
@ -0,0 +1,115 @@
import math
import random
import torch
import numpy as np
from torch import nn
from torch.nn import functional as F
import matplotlib as mpl
mpl.use('Agg')
import matplotlib.pyplot as plt

from functorch import grad, vmap, make_functional

class ThreeLayerNet(nn.Module):
    def __init__(self):
        super(ThreeLayerNet, self).__init__()
        self.fc1 = nn.Linear(1, 40)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(40, 40)
        self.relu2 = nn.ReLU()
        self.fc3 = nn.Linear(40, 1)

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu1(x)
        x = self.fc2(x)
        x = self.relu2(x)
        x = self.fc3(x)
        return x

# The prototype doesn't like F.mse_loss.
def mse_loss(x, y):
    return torch.mean((x - y) ** 2)

params, net, _ = make_functional(ThreeLayerNet())
opt = torch.optim.Adam(params, lr=1e-3)
alpha = 0.1

K = 20
losses = []
num_tasks = 4
def sample_tasks(outer_batch_size, inner_batch_size):
    # Select an amplitude and phase for each task
    As = []
    phases = []
    for _ in range(outer_batch_size):
        As.append(np.random.uniform(low=0.1, high=.5))
        phases.append(np.random.uniform(low=0., high=np.pi))
    def get_batch():
        xs, ys = [], []
        for A, phase in zip(As, phases):
            x = np.random.uniform(low=-5., high=5., size=(inner_batch_size, 1))
            y = A * np.sin(x + phase)
            xs.append(x)
            ys.append(y)
        return torch.tensor(xs, dtype=torch.float), torch.tensor(ys, dtype=torch.float)
    x1, y1 = get_batch()
    x2, y2 = get_batch()
    return x1, y1, x2, y2

for it in range(20000):
    loss2 = 0.0
    opt.zero_grad()
    def get_loss_for_task(x1, y1, x2, y2):
        def inner_loss(params, x1, y1):
            f = net(params, (x1,))
            loss = mse_loss(f, y1)
            return loss

        grads = grad(inner_loss)(params, x1, y1)
        new_params = [(params[i] - alpha*grads[i]) for i in range(len(params))]

        v_f = net(new_params, (x2,))
        return mse_loss(v_f, y2)

    task = sample_tasks(num_tasks, K)
    inner_losses = vmap(get_loss_for_task)(task[0], task[1], task[2], task[3])
    loss2 = sum(inner_losses)/len(inner_losses)
    loss2.backward()

    opt.step()

    if it % 100 == 0:
        print('Iteration %d -- Outer Loss: %.4f' % (it, loss2))
    losses.append(loss2.item())

t_A = torch.tensor(0.0).uniform_(0.1, 0.5)
t_b = torch.tensor(0.0).uniform_(0.0, math.pi)

t_x = torch.empty(4, 1).uniform_(-5, 5)
t_y = t_A*torch.sin(t_x + t_b)

opt.zero_grad()

t_params = params
for k in range(5):
    t_f = net(t_params, (t_x,))
    t_loss = F.l1_loss(t_f, t_y)

    grads = torch.autograd.grad(t_loss, t_params, create_graph=True)
    t_params = [(t_params[i] - alpha*grads[i]) for i in range(len(params))]


test_x = torch.arange(-2*math.pi, 2*math.pi, step=0.01).unsqueeze(1)
test_y = t_A*torch.sin(test_x + t_b)

test_f = net(t_params, (test_x,))

plt.plot(test_x.data.numpy(), test_y.data.numpy(), label='sin(x)')
plt.plot(test_x.data.numpy(), test_f.data.numpy(), label='net(x)')
plt.plot(t_x.data.numpy(), t_y.data.numpy(), 'o', label='Examples')
plt.legend()
plt.savefig('maml-sine.png')
plt.figure()
plt.plot(np.convolve(losses, [.05]*20))
plt.savefig('losses.png')
6
functorch/functorch/__init__.py
Normal file
6
functorch/functorch/__init__.py
Normal file
@ -0,0 +1,6 @@
import torch
from . import _C

from ._src.vmap import vmap
from ._src.eager_transforms import grad, grad_with_value, vjp, jacrev
from ._src.make_functional import make_functional, make_functional_with_buffers
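Editor's note: the exports above are designed to compose. A hedged usage sketch (not part of this commit), assuming the built extension is importable:

# Hypothetical composition sketch: per-row gradients via vmap(grad(f)).
import torch
from functorch import vmap, grad

f = lambda v: (v ** 2).sum()        # R^3 -> R, scalar per slice as grad requires
x = torch.randn(4, 3)
per_row_grads = vmap(grad(f))(x)    # gradient of f for each row of x
assert torch.allclose(per_row_grads, 2 * x)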
0
functorch/functorch/_src/__init__.py
Normal file
185
functorch/functorch/_src/eager_transforms.py
Normal file
@ -0,0 +1,185 @@
import torch
from functools import partial
import collections
import torch.nn as nn
import torch.nn.functional as F
from .make_functional import make_functional, make_functional_with_buffers
import gc

from .vmap import vmap

from functorch._C import (
    _wrap_for_grad,
    _unwrap_for_grad,
    _grad_increment_nesting,
    _grad_decrement_nesting,
)

# x = torch.ones(2, 3)
# y = torch.ones(2, 3)
# # result = vmap(torch.add)(x, y)
# result = vmap(vmap(torch.add))(x, y)

# assert torch.allclose(result, x + y)

# TODO: replace all of these with pytrees
def _create_differentiable(tensor_or_tuple_of_tensors, level=None):
    if isinstance(tensor_or_tuple_of_tensors, torch.Tensor):
        tensor = tensor_or_tuple_of_tensors
        aliased = tensor
        return aliased.requires_grad_()
    if isinstance(tensor_or_tuple_of_tensors, tuple):
        return tuple(map(partial(_create_differentiable, level=level), tensor_or_tuple_of_tensors))
    if isinstance(tensor_or_tuple_of_tensors, list):
        return tuple(map(partial(_create_differentiable, level=level), tensor_or_tuple_of_tensors))
    raise ValueError(f'Thing passed to transform API must be Tensor, List or Tuple, '
                     f'got {type(tensor_or_tuple_of_tensors)}')

def _undo_create_differentiable(tensor_or_tuple_of_tensors, level=None):
    if isinstance(tensor_or_tuple_of_tensors, torch.Tensor):
        tensor = tensor_or_tuple_of_tensors
        return _unwrap_for_grad(tensor, level)
    if isinstance(tensor_or_tuple_of_tensors, tuple):
        return tuple(map(partial(_undo_create_differentiable, level=level), tensor_or_tuple_of_tensors))
    if isinstance(tensor_or_tuple_of_tensors, list):
        return tuple(map(partial(_undo_create_differentiable, level=level), tensor_or_tuple_of_tensors))
    assert False

def _any_differentiable(tensor_or_tuple_of_tensors):
    if isinstance(tensor_or_tuple_of_tensors, torch.Tensor):
        tensor = tensor_or_tuple_of_tensors
        return tensor.requires_grad
    if isinstance(tensor_or_tuple_of_tensors, tuple):
        return any(tuple(map(_any_differentiable, tensor_or_tuple_of_tensors)))
    if isinstance(tensor_or_tuple_of_tensors, list):
        return any(tuple(map(_any_differentiable, tensor_or_tuple_of_tensors)))
    return False

def _wrap_all_tensors(tensor_or_tuple_of_tensors, level):
    if isinstance(tensor_or_tuple_of_tensors, torch.Tensor):
        tensor = tensor_or_tuple_of_tensors
        return _wrap_for_grad(tensor, level)
    if isinstance(tensor_or_tuple_of_tensors, tuple):
        return tuple(map(partial(_wrap_all_tensors, level=level), tensor_or_tuple_of_tensors))
    if isinstance(tensor_or_tuple_of_tensors, list):
        return tuple(map(partial(_wrap_all_tensors, level=level), tensor_or_tuple_of_tensors))
    return tensor_or_tuple_of_tensors

# How do we increment and decrement the nesting? I don't think we can.
def vjp(f, *primals):
    level = _grad_increment_nesting()
    try:
        primals = _wrap_all_tensors(primals, level)
        diff_primals = _create_differentiable(primals, level)
        primals_out = f(*diff_primals)
        results = _undo_create_differentiable(primals_out, level)

        def wrapper(*cotangents, retain_graph=True, create_graph=True):
            result = torch.autograd.grad(primals_out, diff_primals, cotangents,
                                         retain_graph=retain_graph, create_graph=create_graph)
            return result

    finally:
        _grad_decrement_nesting()

    return results, wrapper

def jacrev(f):
    def wrapper_fn(primal):
        output, vjp_fn = vjp(f, primal)
        assert isinstance(output, torch.Tensor)
        # TODO: does jacrev compose with vmap...? the eye call should make it so that it doesn't
        basis = torch.eye(output.numel(), dtype=output.dtype, device=output.device) \
            .view(output.numel(), *output.shape)
        result, = vmap(vjp_fn)(basis)
        result = result.view(*output.shape, *primal.shape)
        return result
    return wrapper_fn

#
#
# def jacrev(f, diff_argnums=(0,)):
#     def wrapper(*args):
#         torch._C._grad_increment_nesting()
#         output = None
#         grad_outputs = None
#         try:
#             args = [_create_differentiable(arg) if i in diff_argnums else arg
#                     for i, arg in enumerate(args)]
#             output = f(*args)
#             # Only support single tensor output for now
#             assert isinstance(output, torch.Tensor)
#             output_numel = output.numel()
#             if output_numel != 0:
#                 grad_output = torch.eye(output_numel).view(output_numel, *output.shape)
#
#             diff_args = [args[i] for i in diff_argnums]
#             single_diff_arg = isinstance(diff_args[0], torch.Tensor) and len(diff_args) == 1
#             # TODO: quick hack...
#             if len(diff_args) == 1 and isinstance(diff_args[0], tuple):
#                 diff_args = diff_args[0]
#             # NB: need create_graph so that backward pass isn't run in no_grad mode
#
#             def compute_vjp(v):
#                 return torch.autograd.grad(output, diff_args, v, create_graph=True)
#
#             if output_numel == 0:
#                 grad_input = compute_vjp(grad_output)
#             else:
#                 grad_input = vmap(compute_vjp)(grad_output)
#
#             if single_diff_arg:
#                 grad_input = grad_input[0]
#         finally:
#             _undo_create_differentiable(args)
#             torch._C._grad_decrement_nesting()
#         return grad_input, output
#     return wrapper

def grad_with_value(f, diff_argnums=(0,), has_aux=False):
    def wrapper(*args):
        level = _grad_increment_nesting()
        output, aux, grad_input = None, None, None
        try:
            args = _wrap_all_tensors(args, level)
            args = [_create_differentiable(arg, level) if i in diff_argnums else arg
                    for i, arg in enumerate(args)]
            # print("calling f(*args)")
            output = f(*args)
            # print("done with f(*args)")
            if has_aux:
                output, aux = output
            # print("calling output.dim()")
            assert output.dim() == 0
            diff_args = [args[i] for i in diff_argnums]
            single_diff_arg = isinstance(diff_args[0], torch.Tensor) and len(diff_args) == 1
            # TODO: quick hack...
            if len(diff_args) == 1 and isinstance(diff_args[0], tuple):
                diff_args = diff_args[0]
            # NB: need create_graph so that backward pass isn't run in no_grad mode
            # import torchviz; import graphviz
            # graph = torchviz.make_dot(output)
            # graph.save("inner.dot")
            # print("calling autograd.grad")
            grad_input = torch.autograd.grad(
                output, diff_args, create_graph=True)
            # print("done-ish!")
            if single_diff_arg:
                grad_input = grad_input[0]
        finally:
            if grad_input is not None:
                grad_input = _undo_create_differentiable(grad_input, level)
            _grad_decrement_nesting()
        if has_aux:
            return grad_input, output, aux
        return grad_input, output
    return wrapper

def grad(f, diff_argnums=(0,), has_aux=False):
    def wrapper(*args):
        results = grad_with_value(f, diff_argnums, has_aux=has_aux)(*args)
        if has_aux:
            return results[0], results[2]
        return results[0]
    return wrapper
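Editor's note: a hedged usage sketch (not part of this file) exercising vjp and jacrev above on an elementwise function, whose Jacobian is diagonal:

# Hypothetical sketch, assuming the extension builds and imports.
import torch
from functorch import vjp, jacrev

f = torch.sin
x = torch.randn(3)

out, f_vjp = vjp(f, x)                   # out == sin(x); f_vjp applies J(f)^T
grad_x, = f_vjp(torch.ones_like(out))    # J^T @ ones == cos(x) elementwise
assert torch.allclose(grad_x, x.cos())

jacobian = jacrev(f)(x)                  # rows recovered by vmap over eye(3)
assert torch.allclose(jacobian, torch.diag(x.cos()))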
99
functorch/functorch/_src/make_functional.py
Normal file
@ -0,0 +1,99 @@
import torch
import torch.nn as nn
from torch import Tensor
from typing import List, Tuple
import copy

# Utilities to make nn.Module "functional"
# In particular the goal is to be able to provide a function that takes as input
# the parameters and evaluates the nn.Module using fixed inputs.
def _del_nested_attr(obj: nn.Module, names: List[str]) -> None:
    """
    Deletes the attribute specified by the given list of names.
    For example, to delete the attribute obj.conv.weight,
    use _del_nested_attr(obj, ['conv', 'weight'])
    """
    if len(names) == 1:
        delattr(obj, names[0])
    else:
        _del_nested_attr(getattr(obj, names[0]), names[1:])

def _set_nested_attr(obj: nn.Module, names: List[str], value: Tensor) -> None:
    """
    Set the attribute specified by the given list of names to value.
    For example, to set the attribute obj.conv.weight,
    use _set_nested_attr(obj, ['conv', 'weight'], value)
    """
    if len(names) == 1:
        setattr(obj, names[0], value)
    else:
        _set_nested_attr(getattr(obj, names[0]), names[1:], value)

def extract_weights(mod: nn.Module) -> Tuple[Tuple[Tensor, ...], List[str]]:
    """
    This function removes all the Parameters from the model and
    returns them as a tuple, as well as their original attribute names.
    The weights must be re-loaded with `load_weights` before the model
    can be used again.
    Note that this function modifies the model in place and after this
    call, mod.parameters() will be empty.
    """
    orig_params = tuple(mod.parameters())
    # Remove all the parameters in the model
    names = []
    for name, p in list(mod.named_parameters()):
        _del_nested_attr(mod, name.split("."))
        names.append(name)

    # Make params regular Tensors instead of nn.Parameter
    params = tuple(p for p in orig_params)
    return params, names

def load_weights(mod: nn.Module, names: List[str], params: Tuple[Tensor, ...], as_params=False) -> None:
    """
    Reload a set of weights so that `mod` can be used again to perform a forward pass.
    Note that the `params` are regular Tensors (that can have history) and so are left
    as Tensors. This means that mod.parameters() will still be empty after this call.
    """
    for name, p in zip(names, params):
        if as_params:
            p = nn.Parameter(p)
        _set_nested_attr(mod, name.split("."), p)

def extract_buffers(mod: nn.Module) -> Tuple[Tuple[Tensor, ...], List[str]]:
    orig_params = tuple(mod.buffers())
    # Remove all the buffers in the model
    names = []
    for name, p in list(mod.named_buffers()):
        _del_nested_attr(mod, name.split("."))
        names.append(name)

    # Return the buffers as regular Tensors
    params = tuple(p for p in orig_params)
    return params, names

def load_buffers(mod: nn.Module, names: List[str], params: Tuple[Tensor, ...], as_params=False) -> None:
    for name, p in zip(names, params):
        _set_nested_attr(mod, name.split("."), p)

def make_functional(model: nn.Module):
    weights, descriptors = extract_weights(model)

    def fun(weights, data):
        mutable_model = copy.deepcopy(model)
        load_weights(mutable_model, descriptors, weights)
        return mutable_model(*data)

    return weights, fun, descriptors

def make_functional_with_buffers(model: nn.Module):
    weights, weight_descriptors = extract_weights(model)
    buffers, buf_descriptors = extract_buffers(model)

    def fun(weights, buffers, data):
        mutable_model = copy.deepcopy(model)
        load_weights(mutable_model, weight_descriptors, weights)
        load_buffers(mutable_model, buf_descriptors, buffers)
        return mutable_model(*data)

    return weights, buffers, fun, weight_descriptors, buf_descriptors
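Editor's note: a hedged sketch (not part of this file) of the functional-module API above. Note that `fun` takes the parameter tuple plus a tuple of inputs:

# Hypothetical usage sketch.
import torch
import torch.nn as nn
from functorch import make_functional

weights, func, _names = make_functional(nn.Linear(3, 2))
x = torch.randn(5, 3)
out = func(weights, (x,))      # inputs are passed as a tuple
assert out.shape == (5, 2)

# `func` closes over a parameter-stripped copy of the module, so the weights
# tuple can be perturbed and re-applied without mutating the module.
shifted = tuple(w + 1.0 for w in weights)
out2 = func(shifted, (x,))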
270
functorch/functorch/_src/vmap.py
Normal file
@ -0,0 +1,270 @@
import torch
import functools
from torch import Tensor
from typing import Any, Callable, Optional, Tuple, Union, List
from torch.utils._pytree import tree_flatten, tree_unflatten, _broadcast_to_and_flatten
import warnings

from functorch._C import (
    _add_batch_dim,
    _remove_batch_dim,
    _vmapmode_decrement_nesting,
    _vmapmode_increment_nesting,
)

in_dims_t = Union[int, Tuple]
out_dims_t = Union[int, Tuple[int, ...]]

# Checks that all args-to-be-batched have the same batch dim size
def _validate_and_get_batch_size(
        flat_in_dims: List[Optional[int]],
        flat_args: List) -> int:
    batch_sizes = [arg.size(in_dim) for in_dim, arg in zip(flat_in_dims, flat_args)
                   if in_dim is not None]
    if batch_sizes and any([size != batch_sizes[0] for size in batch_sizes]):
        raise ValueError(
            f'vmap: Expected all tensors to have the same size in the mapped '
            f'dimension, got sizes {batch_sizes} for the mapped dimension')
    return batch_sizes[0]

def _num_outputs(batched_outputs: Union[Tensor, Tuple[Tensor, ...]]) -> int:
    if isinstance(batched_outputs, tuple):
        return len(batched_outputs)
    return 1

# If value is a tuple, check it has length `num_elements`.
# If value is not a tuple, make a tuple with `value` repeated `num_elements` times
def _as_tuple(value: Any, num_elements: int, error_message_lambda: Callable[[], str]) -> Tuple:
    if not isinstance(value, tuple):
        return (value,) * num_elements
    if len(value) != num_elements:
        raise ValueError(error_message_lambda())
    return value

# Creates BatchedTensors for every Tensor in arg that should be batched.
# Returns the (potentially) batched arguments and the batch_size.
def _create_batched_inputs(
        in_dims: in_dims_t, args: Tuple, vmap_level: int, func: Callable) -> Tuple[Tuple, int]:
    if not isinstance(in_dims, int) and not isinstance(in_dims, tuple):
        raise ValueError(
            f'vmap({_get_name(func)}, in_dims={in_dims}, ...)(<inputs>): '
            f'expected `in_dims` to be int or a (potentially nested) tuple '
            f'matching the structure of inputs, got: {type(in_dims)}.')
    if len(args) == 0:
        raise ValueError(
            f'vmap({_get_name(func)})(<inputs>): got no inputs. Maybe you forgot to add '
            f'inputs, or you are trying to vmap over a function with no inputs. '
            f'The latter is unsupported.')

    flat_args, args_spec = tree_flatten(args)
    flat_in_dims = _broadcast_to_and_flatten(in_dims, args_spec)
    if flat_in_dims is None:
        raise ValueError(
            f'vmap({_get_name(func)}, in_dims={in_dims}, ...)(<inputs>): '
            f'in_dims is not compatible with the structure of `inputs`. '
            f'in_dims has structure {tree_flatten(in_dims)[1]} but inputs '
            f'has structure {args_spec}.')

    for arg, in_dim in zip(flat_args, flat_in_dims):
        if not isinstance(in_dim, int) and in_dim is not None:
            raise ValueError(
                f'vmap({_get_name(func)}, in_dims={in_dims}, ...)(<inputs>): '
                f'Got in_dim={in_dim} for an input but in_dim must be either '
                f'an integer dimension or None.')
        if isinstance(in_dim, int) and not isinstance(arg, Tensor):
            raise ValueError(
                f'vmap({_get_name(func)}, in_dims={in_dims}, ...)(<inputs>): '
                f'Got in_dim={in_dim} for an input but the input is of type '
                f'{type(arg)}. We cannot vmap over non-Tensor arguments, '
                f'please use None as the respective in_dim')
        if in_dim is not None and (in_dim < 0 or in_dim >= arg.dim()):
            raise ValueError(
                f'vmap({_get_name(func)}, in_dims={in_dims}, ...)(<inputs>): '
                f'Got in_dim={in_dim} for some input, but that input is a Tensor '
                f'of dimensionality {arg.dim()} so expected in_dim to satisfy '
                f'0 <= in_dim < {arg.dim()}.')

    batch_size = _validate_and_get_batch_size(flat_in_dims, flat_args)
    # See NOTE [Ignored _remove_batch_dim, _add_batch_dim]
    batched_inputs = [arg if in_dim is None else
                      _add_batch_dim(arg, in_dim, vmap_level)  # type: ignore
                      for in_dim, arg in zip(flat_in_dims, flat_args)]
    return tree_unflatten(batched_inputs, args_spec), batch_size

# Undoes the batching (and any batch dimensions) associated with the `vmap_level`.
def _unwrap_batched(
        batched_outputs: Union[Tensor, Tuple[Tensor, ...]],
        out_dims: out_dims_t,
        vmap_level: int, batch_size: int, func: Callable) -> Tuple:
    num_outputs = _num_outputs(batched_outputs)
    out_dims_as_tuple = _as_tuple(
        out_dims, num_outputs,
        lambda: f'vmap({_get_name(func)}, ..., out_dims={out_dims}): `out_dims` must '
                f'have one dim per output (got {num_outputs} outputs) of {_get_name(func)}.')

    # NOTE [Ignored _remove_batch_dim, _add_batch_dim]
    # There is something wrong with our type bindings for functions that begin
    # with '_', see #40397.
    if isinstance(batched_outputs, Tensor):
        out_dim = out_dims_as_tuple[0]
        return _remove_batch_dim(batched_outputs, vmap_level, batch_size, out_dim)  # type: ignore
    return tuple(_remove_batch_dim(out, vmap_level, batch_size, out_dim)  # type: ignore
                 for out, out_dim in zip(batched_outputs, out_dims_as_tuple))

# Checks that `fn` returned one or more Tensors and nothing else.
# NB: A python function that returns multiple values returns a single tuple,
# so we are effectively checking that `outputs` is a single Tensor or a tuple of
# Tensors.
def _validate_outputs(outputs: Any, func: Callable) -> None:
    if isinstance(outputs, Tensor):
        return
    if not isinstance(outputs, tuple):
        raise ValueError(f'vmap({_get_name(func)}, ...): `{_get_name(func)}` must only return '
                         f'Tensors, got type {type(outputs)} as the return.')
    for idx, output in enumerate(outputs):
        if isinstance(output, Tensor):
            continue
        raise ValueError(f'vmap({_get_name(func)}, ...): `{_get_name(func)}` must only return '
                         f'Tensors, got type {type(output)} for return {idx}.')

def _check_out_dims_is_int_or_int_tuple(out_dims: out_dims_t, func: Callable) -> None:
    if isinstance(out_dims, int):
        return
    if not isinstance(out_dims, tuple) or \
            not all([isinstance(out_dim, int) for out_dim in out_dims]):
        raise ValueError(
            f'vmap({_get_name(func)}, ..., out_dims={out_dims}): `out_dims` must be '
            f'an int or a tuple of int representing where in the outputs the '
            f'vmapped dimension should appear.')

def _get_name(func: Callable):
    if hasattr(func, '__name__'):
        return func.__name__

    # Not all callables have __name__, in fact, only static functions/methods do.
    # A callable created via functools.partial or an nn.Module, to name some
    # examples, don't have a __name__.
    return repr(func)

# vmap(func)(inputs) wraps all Tensor inputs to be batched in BatchedTensors,
# sends those into func, and then unwraps the output BatchedTensors. Operations
# on BatchedTensors perform the batched operations that the user is asking for.
def vmap(func: Callable, in_dims: in_dims_t = 0, out_dims: out_dims_t = 0) -> Callable:
    """
    vmap is the vectorizing map. Returns a new function that maps `func` over some
    dimension of the inputs. Semantically, vmap pushes the map into PyTorch
    operations called by `func`, effectively vectorizing those operations.

    vmap is useful for handling batch dimensions: one can write a function `func`
    that runs on examples and then lift it to a function that can take batches of
    examples with `vmap(func)`. vmap can also be used to compute batched
    gradients when composed with autograd.

    .. warning::
        functorch.vmap is an experimental prototype that is subject to
        change and/or deletion. Please use at your own risk.

    .. note::
        If you're interested in using vmap for your use case, please
        `contact us! <https://github.com/pytorch/pytorch/issues/42368>`_
        We're interested in gathering feedback from early adopters to inform
        the design.

    Args:
        func (function): A Python function that takes one or more arguments.
            Must return one or more Tensors.
        in_dims (int or nested structure): Specifies which dimension of the
            inputs should be mapped over. `in_dims` should have a structure
            like the inputs. If the `in_dim` for a particular input is None,
            then that indicates there is no map dimension. Default: 0.
        out_dims (int or Tuple[int]): Specifies where the mapped dimension
            should appear in the outputs. If `out_dims` is a Tuple, then it should
            have one element per output. Default: 0.

    Returns:
        Returns a new "batched" function. It takes the same inputs as `func`,
        except each input has an extra dimension at the index specified by `in_dims`.
        It returns the same outputs as `func`, except each output has
        an extra dimension at the index specified by `out_dims`.

    .. warning::
        vmap works best with functional-style code. Please do not perform any
        side-effects in `func`, with the exception of in-place PyTorch operations.
        Examples of side-effects include mutating Python data structures and
        assigning values to variables not captured in `func`.

    One example of using `vmap` is to compute batched dot products. PyTorch
    doesn't provide a batched `torch.dot` API; instead of unsuccessfully
    rummaging through docs, use `vmap` to construct a new function.

        >>> torch.dot                                 # [D], [D] -> []
        >>> batched_dot = functorch.vmap(torch.dot)   # [N, D], [N, D] -> [N]
        >>> x, y = torch.randn(2, 5), torch.randn(2, 5)
        >>> batched_dot(x, y)

    `vmap` can be helpful in hiding batch dimensions, leading to a simpler
    model authoring experience.

        >>> batch_size, feature_size = 3, 5
        >>> weights = torch.randn(feature_size, requires_grad=True)
        >>>
        >>> def model(feature_vec):
        >>>     # Very simple linear model with activation
        >>>     return feature_vec.dot(weights).relu()
        >>>
        >>> examples = torch.randn(batch_size, feature_size)
        >>> result = functorch.vmap(model)(examples)

    `vmap` can also help vectorize computations that were previously difficult
    or impossible to batch. One example is higher-order gradient computation.
    The PyTorch autograd engine computes vjps (vector-Jacobian products).
    Computing a full Jacobian matrix for some function f: R^N -> R^N usually
    requires N calls to `autograd.grad`, one per Jacobian row. Using `vmap`,
    we can vectorize the whole computation, computing the Jacobian in a single
    call to `autograd.grad`.

        >>> # Setup
        >>> N = 5
        >>> f = lambda x: x ** 2
        >>> x = torch.randn(N, requires_grad=True)
        >>> y = f(x)
        >>> I_N = torch.eye(N)
        >>>
        >>> # Sequential approach
        >>> jacobian_rows = [torch.autograd.grad(y, x, v, retain_graph=True)[0]
        >>>                  for v in I_N.unbind()]
        >>> jacobian = torch.stack(jacobian_rows)
        >>>
        >>> # Vectorized gradient computation
        >>> def get_vjp(v):
        >>>     return torch.autograd.grad(y, x, v)
        >>> jacobian = functorch.vmap(get_vjp)(I_N)

    .. note::
        vmap does not provide general autobatching or handle variable-length
        sequences out of the box.
    """
    warnings.warn(
        'functorch.vmap is an experimental prototype that is subject to '
        'change and/or deletion. Please use at your own risk. There may be '
        'unexpected performance cliffs due to certain operators not being '
        'implemented. To see detailed performance warnings please use '
        '`torch._C._debug_only_display_vmap_fallback_warnings(True)` '
        'before the call to `vmap`.',
        stacklevel=2)
    return _vmap(func, in_dims, out_dims)

# A version of vmap but without the initial "experimental prototype" warning
def _vmap(func: Callable, in_dims: in_dims_t = 0, out_dims: out_dims_t = 0) -> Callable:
    @functools.wraps(func)
    def wrapped(*args):
        _check_out_dims_is_int_or_int_tuple(out_dims, func)
        vmap_level = _vmapmode_increment_nesting()
        try:
            batched_inputs, batch_size = _create_batched_inputs(in_dims, args, vmap_level, func)
            batched_outputs = func(*batched_inputs)
            _validate_outputs(batched_outputs, func)
            return _unwrap_batched(batched_outputs, out_dims, vmap_level, batch_size, func)
        finally:
            _vmapmode_decrement_nesting()
    return wrapped
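Editor's note: a hedged sketch (not part of this file) of the per-sample-gradient use case hinted at in the docstring above, combining vmap with grad and make_functional from the modules in this commit. The `in_dims=(None, 0, 0)` pytree broadcast and output shapes are assumptions about how the prototype composes:

# Hypothetical sketch: per-sample gradients via vmap(grad(...)).
import torch
from functorch import vmap, grad, make_functional

weights, func, _ = make_functional(torch.nn.Linear(3, 1))

def loss_for_sample(weights, x, y):
    pred = func(weights, (x.unsqueeze(0),))
    return ((pred - y) ** 2).sum()      # scalar, as grad requires

xs, ys = torch.randn(8, 3), torch.randn(8, 1)
per_sample_grads = vmap(grad(loss_for_sample), in_dims=(None, 0, 0))(weights, xs, ys)
# per_sample_grads should be a tuple of tensors, each with a leading dim of 8.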
416
functorch/functorch/csrc/BatchedFallback.cpp
Normal file
@ -0,0 +1,416 @@
#include <functorch/csrc/BatchedFallback.h>
#include <functorch/csrc/VmapTransforms.h>
#include <functorch/csrc/Constants.h>
#include <functorch/csrc/TensorWrapper.h>
#include <functorch/csrc/DynamicLayer.h>

#include <ATen/Context.h>
#include <ATen/MatrixRef.h>
#include <ATen/core/dispatch/Dispatcher.h>
#include <c10/util/accumulate.h>
#include <c10/util/llvmMathExtras.h>

namespace at {
namespace functorch {

// Given a linear index, return the actual index.
// Example: Given linear_idx = 3, sizes = [5, 2], we would return [1, 0]
static at::SmallVector<indexing::TensorIndex,kVmapStaticDimVecSize>
computeIndex(int64_t linear_idx, IntArrayRef sizes) {
  at::SmallVector<indexing::TensorIndex,kVmapStaticDimVecSize> result;
  result.reserve(sizes.size());
  for (auto it = sizes.rbegin(); it != sizes.rend(); it++) {
    auto remainder = linear_idx % *it;
    result.push_back(remainder);
    linear_idx -= remainder;
    linear_idx /= *it;
  }
  std::reverse(std::begin(result), std::end(result));
  return result;
}

static bool areAllReturnsTensors(const at::FunctionSchema& schema) {
  return std::all_of(
      schema.returns().begin(),
      schema.returns().end(),
      [] (const Argument& arg) { return arg.type() == TensorType::get(); });
}

static bool areAnyArgumentsTensorList(const at::FunctionSchema& schema) {
  return std::any_of(
      schema.arguments().begin(),
      schema.arguments().end(),
      [] (const Argument& arg) { return arg.type()->isSubtypeOf(ListType::ofTensors()); });
}

// Returns true if an operator is in-place. An operator is in-place if:
// 1. The first argument is a Tensor and it is being written to
// 2. The first argument is being returned
// 3. No other arguments are aliased
// Here is an example of an in-place operator:
//   add_(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!)
static bool isInplaceOp(const FunctionSchema& schema) {
  if (!schema.is_mutable() || schema.returns().size() != 1) {
    return false;
  }
  // Check that the first argument is being written to
  const auto& first_arg_alias_info = schema.arguments().begin()->alias_info();
  if (!first_arg_alias_info || !first_arg_alias_info.value().isWrite()) {
    return false;
  }
  // Check that none of the other args are being aliased
  for (auto it = schema.arguments().begin() + 1; it != schema.arguments().end(); ++it) {
    const auto& alias_info = it->alias_info();
    if (alias_info) {
      return false;
    }
  }
  // Check that the first tensor is being returned (i.e., output has a (a!))
  const auto& return_alias_info = schema.returns()[0].alias_info();
  return return_alias_info && return_alias_info.value().isWrite();
}

static void warnFallback(const c10::FunctionSchema& schema, bool is_inplace) {
  if (!globalContext().areVmapFallbackWarningsEnabled()) {
    return;
  }
  auto uses_stack = is_inplace ? "" : " and stack";
  TORCH_WARN("Batching rule not implemented for ", schema.operator_name(), " falling back "
             "to slow (for loop", uses_stack, ") implementation");
}

// The general flow of the algorithm is as follows.
// - First, we figure out which arguments are BatchedTensors and save them
//   to a vector. We also store a vector of which index of the arguments list
//   each BatchedTensor appears in. This will be useful for bookkeeping later.
// - Next, we apply the MultiBatchVmapTransform to all of the BatchedTensors.
//   This returns a vector of VmapPhysicalView that hold tensors that contain
//   all of the collective batch dimensions at the front of the tensors.
// - Then, we attempt to call `op` once per slice of the inputs. To do this,
//   we repeatedly slice the input arguments (if they are BatchedTensors),
//   put the sliced (or a not-sliced) version of the input onto the stack, invoke
//   the operator, and then pop the results off the stack.
void batchedTensorInplaceForLoopFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
  const auto& schema = op.schema();
  const auto num_returns = schema.returns().size();
  warnFallback(schema, /*in_place*/true);

  const auto num_arguments = schema.arguments().size();
  const auto arguments = torch::jit::last(stack, num_arguments);
  const auto arguments_begin = stack->size() - num_arguments;

  // `self` is the Tensor being modified in-place
  Tensor self = arguments[0].toTensor();
  const auto* self_impl = maybeGetBatchedImpl(self);
  std::bitset<kVmapMaxTensorDims> self_vmap_levels;
  if (self_impl) {
    self_vmap_levels = createVmapLevelsBitset(self_impl->bdims());
  }

  // Figure out which arguments are BatchedTensor. Save them to a vector.
  // For each BatchedTensor, also record what position of `arguments` they came from.
  at::SmallVector<Tensor,kVmapTransformStaticInputSize> batched_tensor_inputs;
  VmapDimVector batched_tensor_inputs_position;
  for (int64_t idx = 0; idx < arguments.size(); ++idx) {
    const auto& ivalue = arguments[idx];
    if (!ivalue.isTensor()) {
      continue;
    }
    const auto& tensor = ivalue.toTensor();
    if (!tensor.defined()) {
      continue;
    }
    const auto* batched = maybeGetBatchedImpl(tensor);
    if (!batched) {
      continue;
    }

    // NOTE: [vmap-incompatible in-place operations]
    // In-place operations on `self` are not possible if there exists some vmap
    // level `l` such that `self` is not being vmapped on that level but another
    // argument is. For example, let B0 be a batch dim inside vmap and consider
    // vmap(Tensor.add_, in_dims=(None, 0))(torch.ones(3), torch.ones(B0, 3))
    // - self is torch.ones(3) and does not participate in this vmap
    // - other is BatchedTensor(torch.ones(B0, 3))
    // There's no way to do self.add_(other) because `other` has more elements
    // than `self` due to being vmapped over.
    //
    // In the vmap fallback, we should error out when we detect this.
    auto other_vmap_levels = createVmapLevelsBitset(batched->bdims());
    if (self_vmap_levels != (self_vmap_levels | other_vmap_levels)) {
      // Find one vmap level to complain about
      auto additional_bdims = (self_vmap_levels | other_vmap_levels) ^ self_vmap_levels;
      auto offending_level = llvm::findLastSet(additional_bdims.to_ulong());
      // The following prints out "vmap: aten::add_(tensor, ...) is not possible",
      // but it would be better to print out "tensor.add_(...) is not possible".
      // Afaict there's no official way to get the add_ and there is no way to
      // tell if an operator has method or function variants.
      TORCH_CHECK(false,
        "vmap: ", schema.name(), "(self, *extra_args) is not possible because ",
        "there exists a Tensor `other` in extra_args that has more elements ",
        "than `self`. This happened due to `other` being vmapped over but ",
        "`self` not being vmapped over at level ", offending_level, ". ",
        "Please try to use out-of-place operators instead of ", schema.name(), ". ",
        "If said operator is being called inside the PyTorch framework, ",
        "please file a bug report instead.");
    }
    batched_tensor_inputs.push_back(tensor);
    batched_tensor_inputs_position.push_back(idx);
  }
  TORCH_INTERNAL_ASSERT(batched_tensor_inputs.size() > 0);

  // MultiBatchVmapTransform the BatchedTensor arguments. This returns
  // VmapPhysicalViews that contain all of the batch dimensions.
  const auto input_physical_views = MultiBatchVmapTransform::logicalToPhysical(
      batched_tensor_inputs);

  // Compute the total number of batches
  auto num_batch_dims = input_physical_views.front().numBatchDims();
  auto first_physical_view_sizes = input_physical_views.front().tensor().sizes();
  auto batch_sizes = ArrayRef<int64_t>(
      first_physical_view_sizes.begin(), first_physical_view_sizes.begin() + num_batch_dims);
  const auto num_batches = c10::multiply_integers(batch_sizes);
  // Without a shape-checking API, we're unable to compute the correct shape of
  // the output so we just error out.
  TORCH_CHECK(num_batches > 0,
      "Batching rule not implemented for ", schema.operator_name(), ". ",
      "The fallback path does not support vmap over dims of size 0.");

  // Strategy: For each batch, we are going to push slices (where applicable)
  // of the arguments onto `stack`, and call `op`.
  for (int64_t linear_idx = 0; linear_idx < num_batches; ++linear_idx) {
    auto index = computeIndex(linear_idx, batch_sizes);
    auto batched_tensor_inputs_pos_iter = batched_tensor_inputs_position.begin();
    auto input_physical_views_iter = input_physical_views.begin();
    for (int64_t arg_idx = 0; arg_idx < num_arguments; ++arg_idx) {
      // We assume that torch::jit::Stack is backed by vector<IValue> for
      // simplicity. When that is not the case, this code should be updated.
      const auto& argument = (*stack)[arguments_begin + arg_idx];
      if (batched_tensor_inputs_pos_iter == batched_tensor_inputs_position.end()
          || arg_idx != *batched_tensor_inputs_pos_iter) {
        // argument isn't a BatchedTensor
        torch::jit::push(stack, argument);
        continue;
      }
      // argument is a BatchedTensor
      TORCH_INTERNAL_ASSERT(input_physical_views_iter != input_physical_views.end());
      const auto& physical_view_for_argument = *input_physical_views_iter;
      auto thing = physical_view_for_argument.tensor().index(index);
      torch::jit::push(stack, thing);
      batched_tensor_inputs_pos_iter++;
      input_physical_views_iter++;
    }

    op.callBoxed(stack);
    torch::jit::drop(stack, 1);
  }

  // Return the tensor that was written to in-place
  torch::jit::drop(stack, num_arguments);
  torch::jit::push(stack, self);
}

static Tensor safeStack(TensorList tensors) {
  auto is_defined = [](const Tensor& t) { return t.defined(); };
  if (std::all_of(tensors.begin(), tensors.end(), is_defined)) {
    return at::stack(tensors);
  }
  // NOTE [vmap through backward and undefined grad]
  // While vmapping through backward functions (to compute batched grad), it
  // is possible for the backward function to return an undefined grad for some
  // grad_input for each example. In that case, we return an undefined grad.
  //
  // It is theoretically possible for *some* of the examples to produce an
  // undefined grad (a kernel could peek at the gradient values and return an
  // undefined tensor if it determines the gradient is full of zeros). We
  // could handle this by treating the undefined grad as a zero-filled tensor
  // of the correct shape while stacking the tensors together. However I expect
  // this to happen very rarely (I have not been able to find an example in our
  // codebase) so we just error out in this case.
  if (std::none_of(tensors.begin(), tensors.end(), is_defined)) {
    return Tensor();
  }
  TORCH_CHECK(false,
      "vmap: slow fallback received a mix of undefined and defined tensors ",
      "as the result of an operation. This is not supported, please file us ",
      "an issue on github.");
}

// TODO: dedup
static bool participatesInCurrentLevel(const Tensor& self) {
  auto maybe_level = maybeCurrentDynamicLayer();
  TORCH_INTERNAL_ASSERT(maybe_level.has_value());
  auto current_level = maybe_level->layerId();
  auto* maybe_batched_impl = maybeGetBatchedImpl(self);
  if (!maybe_batched_impl) {
    return false;
  }
  const auto& bdims = maybe_batched_impl->bdims();
  TORCH_INTERNAL_ASSERT(bdims.size() == 1);
  auto self_level = bdims.back().level();
  TORCH_INTERNAL_ASSERT(self_level <= current_level);
  return self_level == current_level;
}

static bool ivalueParticipatesInCurrentLevel(const IValue& ivalue) {
  if (!ivalue.isTensor()) {
    return false;
  }
  return participatesInCurrentLevel(ivalue.toTensor());
}

// The general flow of the algorithm is as follows.
// - First, we figure out which arguments are BatchedTensors and save them
//   to a vector. We also store a vector of which index of the arguments list
//   each BatchedTensor appears in. This will be useful for bookkeeping later.
// - Next, we apply the MultiBatchVmapTransform to all of the BatchedTensors.
//   This returns a vector of VmapPhysicalView that hold tensors that contain
//   all of the collective batch dimensions at the front of the tensors.
// - Then, we attempt to call `op` once per slice of the inputs. To do this,
//   we repeatedly slice the input arguments (if they are BatchedTensors),
//   put the sliced (or a not-sliced) version of the input onto the stack, invoke
//   the operator, and then pop the results off the stack.
// - Each result obtained from the previous step is a slice of the total result,
//   so we stack those tensors together to form the final result.
void batchedTensorForLoopFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
  const auto& schema = op.schema();
  const auto num_returns = schema.returns().size();
  const auto num_arguments = schema.arguments().size();
  const auto arguments = torch::jit::last(stack, num_arguments);

  TORCH_CHECK(areAllReturnsTensors(schema) && !areAnyArgumentsTensorList(schema),
              "Batching rule not implemented for ", schema.operator_name(), ". ",
              "We could not generate a fallback.");

  if (std::none_of(arguments.begin(), arguments.end(), ivalueParticipatesInCurrentLevel)) {
    c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
    op.callBoxed(stack);
    return;
  }

  if (isInplaceOp(schema)) {
    batchedTensorInplaceForLoopFallback(op, stack);
    return;
  }
  TORCH_CHECK(!schema.is_mutable() && !schema.hasAnyAliasInfo(),
              "Batching rule not implemented for ", schema.operator_name(), "; ",
              "the fallback path doesn't work on out= or view ops.");
  TORCH_CHECK(num_returns >= 1,
              "Batching rule not implemented for ", schema.operator_name(), ". ",
              "The fallback path does not support operations with no returns.");
  warnFallback(schema, /*in_place*/false);

  const auto arguments_begin = stack->size() - num_arguments;

  // Figure out which arguments are BatchedTensor. Save them to a vector.
  // For each BatchedTensor, also record what position of `arguments` they came from.
  at::SmallVector<Tensor,kVmapTransformStaticInputSize> batched_tensor_inputs;
  VmapDimVector batched_tensor_inputs_position;
  for (int64_t idx = 0; idx < arguments.size(); ++idx) {
    const auto& ivalue = arguments[idx];
    if (!ivalue.isTensor()) {
      continue;
    }
    const auto& tensor = ivalue.toTensor();
    if (!tensor.defined()) {
      continue;
    }
    const auto* batched = maybeGetBatchedImpl(tensor);
    if (!batched) {
      continue;
    }
    batched_tensor_inputs.push_back(tensor);
    batched_tensor_inputs_position.push_back(idx);
  }
  TORCH_INTERNAL_ASSERT(batched_tensor_inputs.size() > 0);

  // MultiBatchVmapTransform the BatchedTensor arguments. This returns
  // VmapPhysicalViews that contain all of the batch dimensions.
  const auto input_physical_views = MultiBatchVmapTransform::logicalToPhysical(
      batched_tensor_inputs);

  // Compute the total number of batches
  auto num_batch_dims = input_physical_views.front().numBatchDims();
  auto some_sizes = input_physical_views.front().tensor().sizes();
  auto batch_sizes = ArrayRef<int64_t>(some_sizes.begin(), some_sizes.begin() + num_batch_dims);
  const auto num_batches = c10::multiply_integers(batch_sizes);
  // Without a shape-checking API, we're unable to compute the correct shape of
  // the output so we just error out.
  TORCH_CHECK(num_batches > 0,
      "Batching rule not implemented for ", schema.operator_name(), ". ",
      "The fallback path does not support vmap over dims of size 0.");

  // Strategy: For each batch, we are going to push slices (where applicable)
  // of the arguments onto `stack`, call `op`, and store the result in
  // `output_shards`.
  //
  // NOTE: [Output shards layout]
  // Assume that the operator has three outputs: a, b, c.
  // The layout of output_shards is as follows:
  //   [ a0, a1, a2, a3, b0, b1, b2, b3, c0, c1, c2, c3 ]
  // This is so that we can call at::stack([a0...a3]), at::stack([b0...b3])
  // more easily in the next step.
  std::vector<Tensor> output_shards(num_batches * num_returns);

  for (int64_t linear_idx = 0; linear_idx < num_batches; ++linear_idx) {
    auto index = computeIndex(linear_idx, batch_sizes);
    auto batched_tensor_inputs_pos_iter = batched_tensor_inputs_position.begin();
    auto input_physical_views_iter = input_physical_views.begin();
    for (int64_t arg_idx = 0; arg_idx < num_arguments; ++arg_idx) {
      // We assume that torch::jit::Stack is backed by vector<IValue> for
      // simplicity. When that is not the case, this code should be updated.
      const auto& argument = (*stack)[arguments_begin + arg_idx];
      if (batched_tensor_inputs_pos_iter == batched_tensor_inputs_position.end()
          || arg_idx != *batched_tensor_inputs_pos_iter) {
        // argument isn't a BatchedTensor
        torch::jit::push(stack, argument);
        continue;
      }
      // argument is a BatchedTensor
      TORCH_INTERNAL_ASSERT(input_physical_views_iter != input_physical_views.end());
      const auto& physical_view_for_argument = *input_physical_views_iter;
      c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
      torch::jit::push(stack, physical_view_for_argument.tensor().index(index));
      batched_tensor_inputs_pos_iter++;
      input_physical_views_iter++;
    }

    // std::cout << "[Fallback]: ";
    // at::dump_tensor((*stack)[stack->size() - 1].toTensor());
    c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
    op.callBoxed(stack);

    // Store the result into `output_shards`. See NOTE: [Output shards layout]
    // to learn about the details of how we store the shards.
    const auto returns = torch::jit::last(stack, num_returns);
    for (int64_t return_idx = 0; return_idx < returns.size(); ++return_idx) {
      output_shards[num_batches * return_idx + linear_idx] = returns[return_idx].toTensor();
    }
    torch::jit::drop(stack, num_returns);
  }

  // For each output Tensor, stack the shards of the tensor together to form a return
  torch::jit::drop(stack, num_arguments);
  auto output_shards_chunks = MatrixRef<Tensor>(output_shards, num_batches);
  for (int64_t return_idx = 0; return_idx < num_returns; ++return_idx) {
    auto shards = output_shards_chunks[return_idx];
    c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
    auto flat_output = safeStack(shards);
    // See NOTE [vmap through backward and undefined grad]
    if (!flat_output.defined()) {
      torch::jit::push(stack, flat_output);
      continue;
    }
    VmapDimVector output_sizes(batch_sizes);
    output_sizes.insert(
        output_sizes.end(),
        flat_output.sizes().begin() + 1,
        flat_output.sizes().end());
    torch::jit::push(
        stack,
        input_physical_views.front().getPhysicalToLogicalMap().apply(flat_output.view(output_sizes)));
  }
}

} // namespace functorch
} // namespace at
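Editor's note: a hedged sketch (not part of this commit) of observing the for-loop fallback above from Python. The debug flag is the one named in vmap.py's warning message; which operators actually hit the fallback depends on batching-rule coverage, so the printed warning is illustrative only:

# Hypothetical sketch: surfacing warnFallback's message.
import torch
from functorch import vmap

torch._C._debug_only_display_vmap_fallback_warnings(True)

x = torch.randn(3, 5)
# If any op inside the vmapped function lacks a batching rule, TORCH_WARN prints:
#   Batching rule not implemented for aten::<op> falling back to slow
#   (for loop and stack) implementation
result = vmap(lambda t: t * 2)(x)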
26
functorch/functorch/csrc/BatchedFallback.h
Normal file
@ -0,0 +1,26 @@
#pragma once
#include <ATen/ATen.h>
#include <ATen/core/op_registration/op_registration.h>
#include <torch/library.h>

namespace at {
namespace functorch {

// If an operator doesn't have a batching rule implemented then we fallback
// to this implementation. The fallback only works on out-of-place operators
// that return only tensors with new memory. (e.g., no in-place operators, no
// view operations).
//
// The fallback effectively takes all of the BatchedTensors in `stack`, slices
// them, and runs `op` on all of the corresponding slices to produce slices
// of the outputs. The output slices then get `torch.stack`ed to create the
// final returns.
//
// The performance of the fallback is not very good because it introduces an
// extra copy from stacking the sliced outputs. Because of this, we prefer to
// write batching rules for operators whenever possible.
void batchedTensorForLoopFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack);


} // namespace functorch
} // namespace at
169
functorch/functorch/csrc/BatchedTensorImpl.cpp
Normal file
@ -0,0 +1,169 @@
#include <functorch/csrc/BatchedTensorImpl.h>

#include <ATen/WrapDimUtils.h>
#include <c10/util/Exception.h>

#include <functorch/csrc/Constants.h>

namespace at {
namespace functorch {

BatchedTensorImpl::BatchedTensorImpl(Tensor value, BatchDims bdims)
  : TensorImpl(
      c10::DispatchKeySet(kBatchedKey),
      value.dtype(),
      value.device()
    )
  , value_(std::move(value))
  , bdims_(std::move(bdims))
{
  TORCH_INTERNAL_ASSERT(value_.defined());
  set_storage_access_should_throw();
  checkInvariants();

  const auto public_dims = value_.dim() - bdims_.size();
  const auto value_sizes = value_.sizes();
  const auto value_strides = value_.strides();
  sizes_and_strides_.resize(public_dims);
  for (int64_t dim = 0; dim < public_dims; dim++) {
    auto actual_dim = actualDim(dim, /*wrap_dim=*/false);
    sizes_and_strides_.size_at_unchecked(dim) = value_sizes.at(actual_dim);
    sizes_and_strides_.stride_at_unchecked(dim) = value_strides.at(actual_dim);
  }
  refresh_numel();
  refresh_contiguous();
}

BatchedTensorImpl::BatchedTensorImpl(DispatchKeySet key_set, Tensor value, BatchDims bdims)
  : TensorImpl(
      key_set.add(kBatchedKey),
      value.dtype(),
      value.device()
    )
  , value_(std::move(value))
  , bdims_(std::move(bdims))
{
  TORCH_INTERNAL_ASSERT(value_.defined());
  checkInvariants();

  TORCH_INTERNAL_ASSERT(bdims_.size() == 1);
  refreshSizesAndStrides();
}

void BatchedTensorImpl::refreshSizesAndStrides() {
  const auto public_dims = value_.dim() - bdims_.size();
  const auto value_sizes = value_.sizes();
  const auto value_strides = value_.strides();
  sizes_and_strides_.resize(public_dims);
  for (int64_t dim = 0; dim < public_dims; dim++) {
    auto actual_dim = actualDim(dim, /*wrap_dim=*/false);
    sizes_and_strides_.size_at_unchecked(dim) = value_sizes.at(actual_dim);
    sizes_and_strides_.stride_at_unchecked(dim) = value_strides.at(actual_dim);
  }
  refresh_numel();
  refresh_contiguous();
}

int64_t BatchedTensorImpl::actualDim(int64_t dim, bool wrap_dim) const {
  if (wrap_dim) {
    const auto ndim = sizes_and_strides_.size();
    dim = maybe_wrap_dim(dim, ndim);
  }
  auto is_bdim = createBatchDimBitset(bdims_);

  // Example: assume dim = 3, and is_bdim = 10010011000...
  // The 1's are batch dims and 0's are normal dims of the underlying value_ Tensor.
  // actualDim gives us the index of `dim` in the `value_` Tensor, which is equivalent
  // to asking "where does the 3rd (0-indexed) zero occur in the bitset?".
  // The answer to that is index 5.
  //
  // TODO(rzou): the PDEP instruction does exactly this
  // (https://stackoverflow.com/questions/7669057/find-nth-set-bit-in-an-int)
  // but it might require newer (>= ~2015) CPUs. We should clean this up
  // if/when we have dropped support for older CPUs.
  int64_t non_bdim_count = 0;
  for (int64_t actual_dim = 0; actual_dim < kVmapMaxTensorDims; actual_dim++) {
    if (is_bdim[actual_dim]) {
      continue;
    }
    if (non_bdim_count == dim) {
      return actual_dim;
    }
    non_bdim_count++;
  }
  // If we hit this assert, then that means
  // `non_bdim_count` + #num_bdims > kVmapMaxTensorDims. We restrict the number
  // of dims a BatchedTensorImpl can have to kVmapMaxTensorDims so this should
  // never be hit.
  TORCH_INTERNAL_ASSERT(false);
}

void BatchedTensorImpl::checkInvariants() const {
  int64_t prev_level = -1;
  for (const auto& bdim : bdims_) {
    TORCH_INTERNAL_ASSERT(bdim.level() > prev_level);
    prev_level = bdim.level();
  }
}

// The following are publicly exposed as methods of Tensor
bool BatchedTensorImpl::is_contiguous(at::MemoryFormat memory_format) const {
  TORCH_CHECK(memory_format == MemoryFormat::Contiguous,
      "NYI: querying is_contiguous inside of vmap for memory_format ",
      "other than torch.contiguous_format");
  return is_contiguous_;
}

// The following are some internal inherited methods that we do not support.
// They should never get called.
void BatchedTensorImpl::set_size(int64_t dim, int64_t new_size) {
  TORCH_INTERNAL_ASSERT(false, "Can't set_size for BatchedTensorImpl");
}
void BatchedTensorImpl::set_stride(int64_t dim, int64_t new_stride) {
  TORCH_INTERNAL_ASSERT(false, "Can't set_stride for BatchedTensorImpl");
}
void BatchedTensorImpl::set_storage_offset(int64_t storage_offset) {
  TORCH_INTERNAL_ASSERT(false, "Can't set_storage_offset for BatchedTensorImpl");
}
#ifdef DEBUG
bool BatchedTensorImpl::has_storage() const {
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!storage_, "BatchedTensorImpl assumes that storage_ is never set");
  return false;
}
#endif

const char* BatchedTensorImpl::tensorimpl_type_name() const {
  return "BatchedTensorImpl";
}

Tensor makeBatched(const Tensor& tensor, BatchDims bdims) {
  DispatchKeySet key_set;
  if (tensor.is_cuda()) {
    key_set = key_set.add(DispatchKey::CUDA);
  }
  return at::detail::make_tensor<BatchedTensorImpl>(key_set, tensor, std::move(bdims));
}

Tensor addBatchDim(const Tensor& tensor, int64_t level, int64_t dim) {
  BatchDims new_bdims = { { level, dim } };
  TORCH_INTERNAL_ASSERT(new_bdims.size() == 1);
  return makeBatched(tensor, std::move(new_bdims));
}

bool inplaceIsVmapCompatible(const Tensor& self, const Tensor& other) {
  const auto* other_batched = maybeGetBatchedImpl(other);
  if (!other_batched) {
    return true;
  }
  const auto* self_batched = maybeGetBatchedImpl(self);
  if (!self_batched) {
    // self is not batched but other is batched
    return false;
  }
  auto self_levels = createVmapLevelsBitset(self_batched->bdims());
  auto other_levels = createVmapLevelsBitset(other_batched->bdims());
  return self_levels == (self_levels | other_levels);
}

} // namespace functorch
} // namespace at
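Editor's note: a hedged sketch (not part of this commit) of the rule enforced by inplaceIsVmapCompatible and by the in-place fallback's error above: `self` must be vmapped at every level `other` is vmapped at. It mirrors the example in NOTE [vmap-incompatible in-place operations]:

# Hypothetical sketch of in-place vmap compatibility.
import torch
from functorch import vmap

y = torch.ones(5, 3)
x = torch.zeros(5, 3)
vmap(torch.Tensor.add_, in_dims=(0, 0))(x, y)   # fine: both sides batched
try:
    # self unbatched, other batched: `other` has more elements than `self`,
    # so this is expected to error out.
    vmap(torch.Tensor.add_, in_dims=(None, 0))(torch.zeros(3), y)
except RuntimeError as err:
    print(err)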
158
functorch/functorch/csrc/BatchedTensorImpl.h
Normal file
@ -0,0 +1,158 @@
#pragma once

#include <bitset>

#include <ATen/ArrayRef.h>
#include <ATen/SmallVector.h>
#include <ATen/Tensor.h>

#include <functorch/csrc/Constants.h>

namespace at {
namespace functorch {

using Tensor = at::Tensor;

// We assume this in a few other places in the codebase,
// but there isn't a centralized definition.
constexpr int64_t kVmapMaxTensorDims = 64;

// The valid vmap levels range from [0, 64). This effectively means that we
// support a maximum of 64 nested vmaps.
constexpr int64_t kVmapNumLevels = 64;

// Store this number of elements of BatchDims on the stack. Most people will
// probably use <= 5 nested vmaps, but adjust this number as necessary.
constexpr int64_t kBatchDimsStackSize = 5;

// A BatchDim represents a "private" dimension on a Tensor created inside of
// vmap. It is a (level, dim) tuple, with the `dim` indicating which dimension
// is being vmap'ed over and the `level` being an identifier for which vmap
// said dimension was created inside. The `dim` corresponds to a "physical
// dim" - it is a dimension index on the underlying physical tensor that is being
// vmapped over.
struct BatchDim {
  BatchDim(int64_t level, int64_t dim) : dim_(dim), level_(level) {}
  int64_t dim() const {
    return dim_;
  }
  int64_t level() const {
    return level_;
  }
 private:
  int64_t dim_;
  int64_t level_;
};

using BatchDims = at::SmallVector<BatchDim, kBatchDimsStackSize>;
using BatchDimsRef = at::ArrayRef<BatchDim>;

// A BatchedTensorImpl holds an underlying Tensor and a list of BatchDims.
// NB: We use the term "BatchedTensor" to mean a Tensor that is backed with a
// BatchedTensorImpl.
//
// The batch dimensions are treated as being "private"; they are not user-visible.
// For example, in the following Tensor,
//    bt = BatchedTensorImpl(ones(2, 3, 5, 7), [(lvl=1, dim=0), (lvl=2, dim=1)])
// dimensions 0 and 1 are batch dimensions.
//
// bt.sizes() returns (5, 7); bt.sum(0) performs a reduction over the (public)
// dim 0, which is equivalent to dim 3 in the underlying ones(2, 3, 5, 7) tensor.
struct BatchedTensorImpl : public c10::TensorImpl {
  explicit BatchedTensorImpl(Tensor value, BatchDims bdims);
  explicit BatchedTensorImpl(at::DispatchKeySet key_set, Tensor value, BatchDims bdims);

  // Returns a reference to BatchDims that represent which dimensions of this
  // tensor are private.
  BatchDimsRef bdims() const { return bdims_; }

  // BatchedTensorImpl wraps a Tensor
  const Tensor& value() const { return value_; };

  // Given a public dimension index, return the dimension index in the underlying
  // value() tensor.
  // For example, if we have
  //    bt = BatchedTensorImpl(ones(2, 3, 5, 7), [(lvl=1, dim=0), (lvl=2, dim=2)])
  // bt.actualDim(0) -> 1
  // bt.actualDim(1) -> 3
  // bt.actualDim(2) -> Error
  int64_t actualDim(int64_t dim, bool wrap_dim = true) const;

  // Override a bunch of methods inherited from TensorImpl to return error messages.
  bool is_contiguous(at::MemoryFormat memory_format=at::MemoryFormat::Contiguous) const override;
  void set_size(int64_t dim, int64_t new_size) override;
  void set_stride(int64_t dim, int64_t new_stride) override;
  void set_storage_offset(int64_t storage_offset) override;
#ifdef DEBUG
  bool has_storage() const override;
#endif

  void refreshSizesAndStrides();

 private:
  // see NOTE: [BatchedTensorImpl levels invariant]
  void checkInvariants() const;
  const char* tensorimpl_type_name() const override;

  Tensor value_;

  // Note: [BatchedTensorImpl levels invariant]
  // There is an invariant that the BatchDims must be stored in increasing `level`
  // order. That is, for i < j, bdims_[i].level must be less than bdims_[j].level.
  BatchDims bdims_;
};

// NB: We use the term "BatchedTensor" to mean a Tensor that is backed with a
// BatchedTensorImpl.
inline bool isBatchedTensor(const Tensor& tensor) {
  return tensor.unsafeGetTensorImpl()->key_set().has(kBatchedKey);
}

// It is unsafe to call this on a Tensor that is not backed by a
// BatchedTensorImpl. Please use `maybeGetBatchedImpl` whenever possible.
inline BatchedTensorImpl* unsafeGetBatchedImpl(Tensor tensor) {
  return static_cast<BatchedTensorImpl*>(tensor.unsafeGetTensorImpl());
}

inline BatchedTensorImpl* maybeGetBatchedImpl(Tensor tensor) {
  if (!isBatchedTensor(tensor)) {
    return nullptr;
  }
  return unsafeGetBatchedImpl(tensor);
}

// Returns a bitset. If bit i is set, then that means dim i is a batchdim.
inline std::bitset<kVmapMaxTensorDims> createBatchDimBitset(BatchDimsRef bdims) {
  std::bitset<kVmapMaxTensorDims> is_bdim;
  for (const auto& bdim : bdims) {
    is_bdim.set(bdim.dim());
  }
  return is_bdim;
}

// Creates a bitset for all of the levels present in `bdims`
inline std::bitset<kVmapNumLevels> createVmapLevelsBitset(BatchDimsRef bdims) {
  std::bitset<kVmapNumLevels> result;
  for (const auto& bdim : bdims) {
    result.set(bdim.level());
  }
  return result;
}

inline std::ostream& operator<<(std::ostream& out, const BatchDim& bdim) {
  out << "(lvl=" << bdim.level() << ", dim=" << bdim.dim() << ")";
  return out;
}

// Use this to construct a BatchedTensor from a regular Tensor
|
||||
TORCH_API Tensor makeBatched(const Tensor& tensor, BatchDims bdims);
|
||||
|
||||
// Adds a batch dim to `tensor`, returning a BatchedTensor
|
||||
TORCH_API Tensor addBatchDim(const Tensor& tensor, int64_t level, int64_t dim);
|
||||
|
||||
// Checks if an inplace operation on self and other is "vmap compatible".
|
||||
// See NOTE: [vmap-incompatible in-place operations] for the definition of this.
|
||||
TORCH_API bool inplaceIsVmapCompatible(const Tensor& self, const Tensor& other);
|
||||
|
||||
}
|
||||
}
|
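To make the inline helpers above concrete, a small usage sketch (assuming the functorch headers are on the include path), using the same (lvl, dim) pairs as the actualDim comment:

#include <functorch/csrc/BatchedTensorImpl.h>
#include <iostream>

using namespace at::functorch;

int main() {
  BatchDims bdims = { BatchDim(/*level=*/1, /*dim=*/0), BatchDim(/*level=*/2, /*dim=*/2) };
  auto dim_bits = createBatchDimBitset(bdims);      // bits 0 and 2 set: physical batch dims
  auto level_bits = createVmapLevelsBitset(bdims);  // bits 1 and 2 set: vmap levels
  std::cout << dim_bits[0] << dim_bits[1] << dim_bits[2] << "\n";  // 101
  std::cout << level_bits[1] << level_bits[2] << "\n";             // 11
  for (const auto& bdim : bdims) {
    std::cout << bdim << " ";  // (lvl=1, dim=0) (lvl=2, dim=2)
  }
  std::cout << "\n";
  return 0;
}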
155
functorch/functorch/csrc/BatchingMetaprogramming.h
Normal file
@@ -0,0 +1,155 @@
#pragma once
#include <ATen/Tensor.h>

namespace at {
namespace functorch {

// Metaprogramming things
template <class... Items> using typelist = c10::guts::typelist::typelist<Items...>;
template <class TypeList> using head_t = c10::guts::typelist::head_t<TypeList>;
template <class TL1, class TL2> using concat_t = c10::guts::typelist::concat_t<TL1, TL2>;
template <typename T> class debug_t;

// tail operation
template<class TypeList>
struct tail final {
  static_assert(c10::guts::false_t<TypeList>::value,
                "In typelist::tail<T>, the T argument must be typelist<...>.");
};
template<class Head, class... Tail>
struct tail<typelist<Head, Tail...>> final {
  using type = typelist<Tail...>;
};
template<class TypeList> using tail_t = typename tail<TypeList>::type;

template <class First, class Second, class Next, class Tail>
struct IfFirstIsTensorAndSecondisBatchDimThenTailElseNext {
  using type = Next;
};
template <class Next, class Tail>
struct IfFirstIsTensorAndSecondisBatchDimThenTailElseNext<Tensor, optional<int64_t>, Next, Tail> {
  using type = Tail;
};
template <class Next, class Tail>
struct IfFirstIsTensorAndSecondisBatchDimThenTailElseNext<const Tensor&, optional<int64_t>, Next, Tail> {
  using type = Tail;
};
template <class Next, class Tail>
struct IfFirstIsTensorAndSecondisBatchDimThenTailElseNext<Tensor&, optional<int64_t>, Next, Tail> {
  using type = Tail;
};
template <class TypeList> struct RemoveBatchDimAfterTensor {
  using first = head_t<TypeList>;
  using next = tail_t<TypeList>;
  using second = head_t<next>;
  using tail = tail_t<next>;

  using type = concat_t<
    typelist<first>,
    typename RemoveBatchDimAfterTensor<
      typename IfFirstIsTensorAndSecondisBatchDimThenTailElseNext<first, second, next, tail>::type
    >::type
  >;
};
template <class Type> struct RemoveBatchDimAfterTensor<typelist<Type>> {
  using type = typelist<Type>;
};
template <> struct RemoveBatchDimAfterTensor<typelist<>> {
  using type = typelist<>;
};
template<class TypeList> using remove_batch_dim_after_tensor_t = typename RemoveBatchDimAfterTensor<TypeList>::type;

// TODO: get rid of these
// Do I need templates on templates now?
// template <typename func_t> struct LowerToNextLayer {};
// template <typename Return, typename... Args> struct LowerToNextLayer<Return(Args...)> {
//   // How to pass in batch_rule directly?
//   static Return apply(Args... args);
// };

template <typename batch_rule_t, typename Result, typename... Args>
Result lowerToNextLayer(batch_rule_t batch_rule, Args... args);

//# Tensor lowerToNextLayer(
//#   std::function<std::tuple<Tensor,optional<int64_t>>(const Tensor&, optional<int64_t>)> batch_rule,
//#   const Tensor& tensor);
std::tuple<Tensor,optional<int64_t>> abs_batch_rule(const Tensor& tensor, optional<int64_t> batch_dim);

template<typename F, F Func, typename Return, typename TupleArgs> struct TORCH_API Dummy {};

template<typename F, F Func, typename Return, typename...T> struct Dummy<F, Func, Return, std::tuple<T...>> {
  static Return apply(T... args) {
    return lowerToNextLayer(abs_batch_rule, std::forward<T>(args)...);
  }
};

template <typename T> struct UnpackSingleItemTuple {
  using type = T;
};
template <typename T> struct UnpackSingleItemTuple<std::tuple<T>> {
  using type = T;
};
template <typename T> using unpack_single_item_tuple_t = typename UnpackSingleItemTuple<T>::type;

template <typename Return, typename TupleArgs> struct BuildFunctionHelper;
template <typename Return, typename... Args> struct BuildFunctionHelper<Return, std::tuple<Args...>> {
  using type = Return(Args...);
};
template <typename Return, typename TL>
struct BuildFunction {
  using type = typename BuildFunctionHelper<Return, c10::guts::typelist::to_tuple_t<TL>>::type;
};
template <typename Return, typename TL> using build_function_t = typename BuildFunction<Return, TL>::type;

// std::tuple<Tensor,optional<int64_t>> (*kAbsBatchRule)(const Tensor& Tensor, optional<int64_t>)
//   = &abs_batch_rule;
template <typename batch_rule_t> struct ToOperatorType {
  using batch_rule_return_type = typename c10::guts::function_traits<batch_rule_t>::return_type;
  using batch_rule_parameter_types = typename c10::guts::function_traits<batch_rule_t>::parameter_types;

  using operator_parameter_types = remove_batch_dim_after_tensor_t<batch_rule_parameter_types>;
  using operator_return_type =
    unpack_single_item_tuple_t<
      c10::guts::typelist::to_tuple_t<
        remove_batch_dim_after_tensor_t<
          c10::guts::typelist::from_tuple_t<batch_rule_return_type>>>>;

  using type = build_function_t<operator_return_type, operator_parameter_types>;
};
template <typename batch_rule_t> using to_operator_t = typename ToOperatorType<batch_rule_t>::type;

template <typename F, F Func> struct TORCH_API PrimBatchRule3 {
  using func_t = to_operator_t<typename std::remove_pointer<F>::type>;
  using result_type = typename c10::guts::function_traits<func_t>::return_type;
  using parameter_types = c10::guts::typelist::to_tuple_t<typename c10::guts::function_traits<func_t>::parameter_types>;
  static auto apply = Dummy<F, Func, result_type, parameter_types>::apply;
};

template<typename Return, typename TypeList> struct TORCH_API PrimBatchRule5 {};
template<typename Return, typename... T> struct PrimBatchRule5<Return, typelist<T...>> {
  static inline Return apply(T... args) {
    return lowerToNextLayer(abs_batch_rule, std::forward<T>(args)...);
  }
};

template<typename func_t> struct PrimBatchRule6 {};
template<typename Return, typename... Args> struct PrimBatchRule6<Return (Args...)> {
  static inline Return apply(Args... args) {
    return lowerToNextLayer(abs_batch_rule, std::forward<Args>(args)...);
  }
};

// template<typename batch_rule_t, batch_rule_t BatchRule> struct PrimBatchRule7 {};
// template<typename batch_rule_t, batch_rule_t BatchRule, typename BRReturn, typename... BRArgs>
// struct PrimBatchRule7<BRReturn(*)(BRArgs...), BatchRule> {
template<typename br_t, br_t BatchRule, typename func_t> struct PrimBatchRule7 {};
template<typename br_t, br_t BatchRule, typename Return, typename... Args> struct PrimBatchRule7<
    br_t, BatchRule, Return (Args...)> {
  static inline Return apply(Args... args) {
    return lowerToNextLayer<br_t, Return, Args...>(BatchRule, std::forward<Args>(args)...);
  }
};

} // namespace functorch
} // namespace at
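A compile-time sketch of what remove_batch_dim_after_tensor_t computes, using hypothetical signatures: every optional<int64_t> that immediately follows a Tensor-like parameter is dropped, which is how a batch rule's signature collapses back into the corresponding operator signature.

#include <functorch/csrc/BatchingMetaprogramming.h>
#include <type_traits>

namespace {
using at::functorch::typelist;
using at::functorch::remove_batch_dim_after_tensor_t;

// Batch rule parameters: (const Tensor&, optional<int64_t> bdim, double)
using batch_rule_args = typelist<const at::Tensor&, c10::optional<int64_t>, double>;
// Corresponding operator parameters: (const Tensor&, double)
static_assert(std::is_same<
    remove_batch_dim_after_tensor_t<batch_rule_args>,
    typelist<const at::Tensor&, double>>::value,
    "the bdim argument after the Tensor should be removed");
} // namespace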
1700
functorch/functorch/csrc/BatchingRegistrations.cpp
Normal file
@@ -0,0 +1,1700 @@
File diff suppressed because it is too large.
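Since the 1700-line registrations diff is suppressed by the viewer, here is a hedged sketch of the general shape of a vmap batching rule built on the VmapTransforms API that appears later in this commit. The rule name and body are illustrative, not the suppressed file's actual contents, and the tensor() accessor on VmapPhysicalView is an assumption (the VmapTransforms.h excerpt below is truncated before its accessors):

// Illustrative only: convert the logical (user-visible) argument to a
// physical view, run the regular operator on the physical tensor, then
// map the physical result back to a logical BatchedTensor.
#include <functorch/csrc/VmapTransforms.h>

namespace at { namespace functorch {

Tensor sum_batching_rule(const Tensor& self, IntArrayRef dims, bool keepdim,
                         optional<ScalarType> dtype) {
  auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
  // Logical dim indices shift past the front batch dims.
  auto dims_physical = self_physical.getPhysicalDims(dims);
  auto result = self_physical.tensor().sum(dims_physical, keepdim, dtype);
  return self_physical.getPhysicalToLogicalMap().apply(result);
}

}} // namespace at::functorch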
9
functorch/functorch/csrc/Constants.h
Normal file
@@ -0,0 +1,9 @@
#pragma once
#include <c10/core/DispatchKey.h>

namespace at {
namespace functorch {

constexpr auto kBatchedKey = c10::DispatchKey::BatchedOutOfTree;

}} // namespace at::functorch
377
functorch/functorch/csrc/DynamicLayer.cpp
Normal file
@@ -0,0 +1,377 @@
#include <functorch/csrc/DynamicLayer.h>
#include <functorch/csrc/TensorWrapper.h>

#include <torch/library.h>
#include <c10/core/impl/LocalDispatchKeySet.h>
#include <ATen/core/dispatch/Dispatcher.h>
#include <torch/csrc/autograd/variable.h>
#include <c10/util/ThreadLocalDebugInfo.h>

namespace at {
namespace functorch {

// Initial autograd layer, because autograd is always "on"
// thread_local std::vector<DynamicLayer> dynamicLayerStack = { DynamicLayer(DispatchKey::Autograd, 1) };

using DynmetaData = std::unordered_map<int64_t, std::shared_ptr<bool>>;
DynmetaData kDynMetaDataSingleton;

static DynmetaData& getGlobalDynmetaData() {
  return kDynMetaDataSingleton;
}

class DynamicLayerStackHolder : public c10::DebugInfoBase {
 public:
  DynamicLayerStackHolder() {}
  virtual ~DynamicLayerStackHolder() {}

  std::vector<DynamicLayer> dynamicLayerStack = { DynamicLayer(DispatchKey::Autograd, 1) };
};

thread_local std::shared_ptr<DynamicLayerStackHolder> kDynamicLayerStack;

static std::vector<DynamicLayer>& dynamicLayerStackAccessor() {
  if (kDynamicLayerStack == nullptr) {
    kDynamicLayerStack = std::make_shared<DynamicLayerStackHolder>();
    c10::ThreadLocalDebugInfo::_push(
        // TODO: this isn't a PRODUCER_INFO, but there's nothing else we can use
        c10::DebugInfoKind::PRODUCER_INFO,
        kDynamicLayerStack);
  }
  TORCH_INTERNAL_ASSERT(kDynamicLayerStack != nullptr);
  // TODO: can figure out how to memoize this. std::call_once with thread_local?
  return kDynamicLayerStack->dynamicLayerStack;
}

std::shared_ptr<bool> getLifeHandleForLevel(int64_t level) {
  auto it = getGlobalDynmetaData().find(level);
  TORCH_INTERNAL_ASSERT(it != kDynMetaDataSingleton.end(), "level should be alive");
  return it->second;
}

optional<DynamicLayer> maybeCurrentDynamicLayer() {
  auto& dynamicLayerStack = dynamicLayerStackAccessor();
  // NB: Exception for regular autograd, maybe tweak this
  if (dynamicLayerStack.size() <= 1) {
    return {};
  }
  return dynamicLayerStack.back();
}

const std::vector<DynamicLayer>& getDynamicLayerStack() {
  return dynamicLayerStackAccessor();
}

void setDynamicLayerStack(const std::vector<DynamicLayer>& stack) {
  dynamicLayerStackAccessor() = stack;
}

static DynamicLayer popDynamicLayer() {
  auto& dynamicLayerStack = dynamicLayerStackAccessor();
  TORCH_INTERNAL_ASSERT(dynamicLayerStack.size() > 0);
  auto result = dynamicLayerStack.back();
  TORCH_INTERNAL_ASSERT(result.key() != DispatchKey::Undefined);
  dynamicLayerStack.pop_back();

  if (dynamicLayerStack.size() == 0) {
    // std::cout << "DynamicLayer off" << std::endl;
    c10::impl::tls_set_dispatch_key_included(DispatchKey::DynamicLayerFront, false);
    c10::impl::tls_set_dispatch_key_included(DispatchKey::DynamicLayerBack, false);
  }

  return result;
}

static int64_t pushDynamicLayer(DispatchKey key) {
  auto& dynamicLayerStack = dynamicLayerStackAccessor();
  TORCH_INTERNAL_ASSERT(key != DispatchKey::Undefined);
  TORCH_INTERNAL_ASSERT(key != DispatchKey::Batched);
  auto layerId = 1 + dynamicLayerStack.size();
  dynamicLayerStack.emplace_back(key, layerId);

  if (layerId == 2) {
    // std::cout << "DynamicLayer on" << std::endl;
    c10::impl::tls_set_dispatch_key_included(DispatchKey::DynamicLayerFront, true);
    c10::impl::tls_set_dispatch_key_included(DispatchKey::DynamicLayerBack, true);
  }

  return layerId;
}

int64_t initAndPushDynamicLayer(DispatchKey key) {
  auto layerId = pushDynamicLayer(key);
  auto& data = getGlobalDynmetaData();
  TORCH_INTERNAL_ASSERT(data.find(layerId) == data.end());
  data[layerId] = std::make_shared<bool>(true);
  return layerId;
}

DynamicLayer popDynamicLayerAndDeleteMetadata() {
  auto result = popDynamicLayer();
  auto level = result.layerId();

  // TODO: is this lock safe? No one else should be writing to the same bucket
  if (c10::show_dispatch_trace_enabled()) {
    std::cout << "deleting metadata" << std::endl;
  }
  auto& data = getGlobalDynmetaData();
  auto it = data.find(level);
  if (it == data.end()) {
    return result;
  }
  if (c10::show_dispatch_trace_enabled()) {
    std::cout << "deleted metadata for level " << level << std::endl;
  }
  // invalidate the thing
  *(it->second) = false;
  data.erase(level);
  return result;
}

static Tensor materializeGradWrappers(const Tensor& tensor, const std::vector<DynamicLayer>& dynlayerStack) {
  if (!tensor.defined()) {
    return tensor;
  }
  // TODO: First entry in the stack is a default autograd key.
  // We should clean up the logic
  if (dynlayerStack.size() <= 1) {
    return tensor;
  }
  if (dynlayerStack.back().key() != DispatchKey::Autograd) {
    return tensor;
  }
  auto cur_level = dynlayerStack.back().layerId();
  auto* wrapper = maybeGetTensorWrapper(tensor);
  if (!wrapper) {
    return makeTensorWrapper(tensor, cur_level);
  }
  TORCH_INTERNAL_ASSERT(wrapper->level().value() <= cur_level, "escaped?");
  if (wrapper->level().value() == cur_level) {
    TORCH_INTERNAL_ASSERT(tensor.defined());
    return tensor;
  }
  return makeTensorWrapper(tensor, cur_level);
}

static Tensor unwrapIfDead(const Tensor& tensor) {
  auto* wrapped = maybeGetTensorWrapper(tensor);
  if (!wrapped) {
    return tensor;
  }
  if (wrapped->is_alive()) {
    return tensor;
  }
  return wrapped->value();
}

static void foreachTensorInplace(std::vector<IValue>& args, int64_t begin, int64_t end,
    std::function<Tensor(const Tensor&)> func) {
  TORCH_INTERNAL_ASSERT(begin >= 0);
  TORCH_INTERNAL_ASSERT(end >= 0);
  TORCH_INTERNAL_ASSERT(begin <= end);
  for (int64_t idx = begin; idx < end; idx++) {
    auto ivalue = args[idx];
    if (ivalue.isTensorList()) {
      auto list = ivalue.toTensorList();
      for (int64_t list_idx = 0; list_idx < list.size(); list_idx++) {
        list[list_idx] = func(list[list_idx]);
      }
      args[idx] = list;
    }
    TORCH_INTERNAL_ASSERT(!ivalue.isGenericDict(), "No operators can accept GenericDict");
    if (!ivalue.isTensor()) {
      continue;
    }
    Tensor value = ivalue.toTensor();
    Tensor replacement = func(value);
    args[idx] = std::move(replacement);
    // sanity checks
    if (ivalue.toTensor().defined()) {
      TORCH_INTERNAL_ASSERT(args[idx].toTensor().defined());
    }
  }
}

constexpr DispatchKeySet all_dynlayer_keyset = DispatchKeySet({
  DispatchKey::DynamicLayerFront,
  DispatchKey::DynamicLayerBack,
  DispatchKey::TensorWrapper,
  // DispatchKey::Batched,
  DispatchKey::BatchedOutOfTree,
  DispatchKey::InplaceOrView
}) | autograd_dispatch_keyset;

static void sanityCheckStack(torch::jit::Stack* stack) {
  if (stack->size() > 0) {
    auto last_ivalue = (*stack)[stack->size() - 1];
    if (last_ivalue.isTensor()) {
      auto tensor = last_ivalue.toTensor();
      auto* wrapper = maybeGetTensorWrapper(tensor);
      TORCH_INTERNAL_ASSERT(wrapper == nullptr);
      TORCH_INTERNAL_ASSERT(tensor.has_storage());
    }
  }
}

void dynamicLayerFrontFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
  auto& dynamicLayerStack = dynamicLayerStackAccessor();
  if (c10::show_dispatch_trace_enabled()) {
    std::cout << "DLS size: " << dynamicLayerStack.size() << std::endl;
  }
  if (dynamicLayerStack.size() == 0) {
    sanityCheckStack(stack);
    c10::impl::ExcludeDispatchKeyGuard guard(all_dynlayer_keyset);
    op.callBoxed(stack);
    return;
  }

  // Unwrap dead GradWrappers, materialize live ones
  auto maybeTransformGradWrappers = [](const Tensor& tensor) {
    auto result = unwrapIfDead(tensor);
    return materializeGradWrappers(result, getDynamicLayerStack());
  };
  auto num_args = op.schema().arguments().size();
  foreachTensorInplace(*stack, stack->size() - num_args, stack->size(), maybeTransformGradWrappers);

  auto layer = dynamicLayerStack.back();

  DispatchKeySet exclude = DispatchKeySet::FULL;
  exclude = exclude.remove(DispatchKey::DynamicLayerBack);
  if (layer.key() == DispatchKey::Autograd) {
    exclude = exclude - autograd_dispatch_keyset;
    exclude = exclude.remove(DispatchKey::InplaceOrView);
  // } else if (layer.key() == DispatchKey::Batched) {
  //   exclude = exclude.remove(DispatchKey::Batched);
  } else if (layer.key() == DispatchKey::BatchedOutOfTree) {
    exclude = exclude.remove(DispatchKey::BatchedOutOfTree);
  } else {
    TORCH_INTERNAL_ASSERT(false);
  }
  c10::impl::ExcludeDispatchKeyGuard guard(exclude);

  // Re-dispatch
  op.callBoxed(stack);
}

struct WithoutTop {
  WithoutTop(): layer_(popDynamicLayer()) {
  }
  ~WithoutTop() {
    pushDynamicLayer(layer_.key());
  }

  bool prev_grad_enabled_;
  DynamicLayer layer_;
};

void dynamicLayerBackFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
  auto cur_level = getDynamicLayerStack().back().layerId();
  auto cur_key = getDynamicLayerStack().back().key();

  auto unwrap = [&](const Tensor& tensor) {
    if (!tensor.defined()) {
      return tensor;
    }
    auto* maybe_tensor_wrapper = maybeGetTensorWrapper(tensor);
    if (!maybe_tensor_wrapper) {
      return tensor;
    }
    if (maybe_tensor_wrapper->level().value() == cur_level) {
      return maybe_tensor_wrapper->value();
    }
    if (c10::show_dispatch_trace_enabled()) {
      std::cout << "unwrap " << cur_level << std::endl;
    }
    return tensor;
  };
  auto wrap = [&](const Tensor& tensor) {
    if (!tensor.defined()) {
      return tensor;
    }
    if (cur_level == 1) {
      return tensor;
    }
    if (c10::show_dispatch_trace_enabled()) {
      std::cout << "wrap " << cur_level << std::endl;
    }
    return makeTensorWrapper(tensor, cur_level);
  };

  // TODO: we only need to do the following (marked with !) on in-place functions
  // that modify sizes or strides. There aren't many of them.
  // If autograd dispatch key:
  // 1. (!) Put a copy of all of the args onto the stack
  // 2. Unwrap all the args in the copy set
  // 3. Call the operator
  // 4. Wrap the output
  // 5. (!) refreshSizesAndStrides for all the args in the original set
  // 6. (!) Pop those args off.

  // Step 1 & 2
  if (cur_key == DispatchKey::Autograd) {
    auto args_size = op.schema().arguments().size();
    // Step 1
    auto front = stack->size() - args_size;
    for (int64_t arg_idx = 0; arg_idx < args_size; arg_idx++) {
      stack->push_back((*stack)[front + arg_idx]);
    }
    // Step 2
    foreachTensorInplace(*stack, stack->size() - args_size, stack->size(), unwrap);
  }

  // pop the top layer. Put it back on dtor.
  WithoutTop guard;

  // "reset exclude set"
  // TODO: Still a problem with composability and AutoNonVariableTypeGuard.
  // Users cannot do torch.no_grad otherwise there will be problems.
  auto keyset = c10::impl::PODLocalDispatchKeySet();
  c10::impl::_force_tls_local_dispatch_key_set(keyset);
  c10::impl::tls_set_dispatch_key_included(DispatchKey::DynamicLayerFront, true);
  c10::impl::tls_set_dispatch_key_included(DispatchKey::DynamicLayerBack, true);

  // Re-dispatch
  op.callBoxed(stack);

  // Step 4, 5, 6
  if (cur_key == DispatchKey::Autograd) {
    // Step 4
    auto ret_size = op.schema().returns().size();
    foreachTensorInplace(*stack, stack->size() - ret_size, stack->size(), wrap);

    // Step 5
    auto args_size = op.schema().arguments().size();
    auto args_front = stack->size() - args_size - ret_size;
    for (int64_t arg_idx = 0; arg_idx < args_size; arg_idx++) {
      auto& ivalue = (*stack)[args_front + arg_idx];
      if (!ivalue.isTensor()) {
        continue;
      }
      auto maybe_tensor_wrapper = maybeGetTensorWrapper(ivalue.toTensor());
      if (!maybe_tensor_wrapper) {
        continue;
      }
      maybe_tensor_wrapper->refreshSizesAndStrides();
    }

    // Step 6
    stack->erase(stack->end() - (args_size + ret_size), stack->end() - ret_size);
  }
}

TORCH_LIBRARY_IMPL(_, DynamicLayerFront, m) {
  m.fallback(torch::CppFunction::makeFromBoxedFunction<&dynamicLayerFrontFallback>());
}

TORCH_LIBRARY_IMPL(_, DynamicLayerBack, m) {
  m.fallback(torch::CppFunction::makeFromBoxedFunction<&dynamicLayerBackFallback>());
}

// TORCH_LIBRARY_IMPL(aten, DynamicLayerFront, m) {
//   m.impl("_unwrap_for_grad", native::_unwrap_for_grad);
//   m.impl("dump_tensor", native::dump_tensor);
//   m.impl("dlevel", native::dlevel);
// }

} // namespace functorch
} // namespace at
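The life-handle protocol above is the part worth internalizing: each level owns a shared_ptr<bool> that wrappers capture, and popping the level flips it to false, which is how dead wrappers are later detected and unwrapped. A minimal sketch with the dispatcher machinery stubbed out (plain shared_ptr<bool>, no real layers):

#include <iostream>
#include <memory>

int main() {
  // initAndPushDynamicLayer(...) stores a handle like this in the global map.
  auto life_handle = std::make_shared<bool>(true);

  // A TensorWrapper created at this level captures a copy of the handle.
  std::shared_ptr<bool> captured_by_wrapper = life_handle;

  // popDynamicLayerAndDeleteMetadata() invalidates the handle in place.
  *life_handle = false;

  // is_alive() on the wrapper now reports false; fallbacks unwrap it.
  std::cout << std::boolalpha << *captured_by_wrapper << "\n";  // false
  return 0;
}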
33
functorch/functorch/csrc/DynamicLayer.h
Normal file
@@ -0,0 +1,33 @@
#pragma once
#include <c10/core/DispatchKey.h>
#include <c10/util/Optional.h>
#include <unordered_map>
#include <mutex>

// Forward declaration to avoid pulling in the full header
namespace c10 { struct AutogradMetaInterface; }

namespace at {
namespace functorch {

struct TORCH_API DynamicLayer {
  DynamicLayer(DispatchKey key, int64_t layerId): key_(key), layerId_(layerId) {}

  DispatchKey key() const { return key_; }
  int64_t layerId() const { return layerId_; }
 private:
  DispatchKey key_;
  int64_t layerId_;
};

TORCH_API int64_t initAndPushDynamicLayer(DispatchKey key);
TORCH_API DynamicLayer popDynamicLayerAndDeleteMetadata();
TORCH_API c10::optional<DynamicLayer> maybeCurrentDynamicLayer();
TORCH_API const std::vector<DynamicLayer>& getDynamicLayerStack();
TORCH_API void setDynamicLayerStack(const std::vector<DynamicLayer>& stack);

// NB: not lock safe. TODO: does it need a lock?
TORCH_API std::shared_ptr<bool> getLifeHandleForLevel(int64_t level);

} // namespace functorch
} // namespace at
232
functorch/functorch/csrc/TensorWrapper.cpp
Normal file
@@ -0,0 +1,232 @@
#include <functorch/csrc/TensorWrapper.h>
#include <functorch/csrc/DynamicLayer.h>
#include <functorch/csrc/BatchedTensorImpl.h>

#include <torch/library.h>
#include <ATen/core/dispatch/Dispatcher.h>

namespace at {
namespace functorch {

void dumpTensor(std::ostream& ss, const Tensor& tensor) {
  auto* wrapped = maybeGetTensorWrapper(tensor);
  if (!wrapped) {
    auto* batched = maybeGetBatchedImpl(tensor);
    if (batched) {
      ss << "Batched[" << batched->bdims() << ", ";
      dumpTensor(ss, batched->value());
      ss << "]";
      return;
    }
    ss << "Tensor" << tensor.sizes();
    return;
  }
  ss << "Wrapper[";
  if (wrapped->level().has_value()) {
    ss << wrapped->level().value() << ", ";
  } else {
    ss << "dead, ";
  }
  dumpTensor(ss, wrapped->value());
  ss << "]";
}

void TensorWrapper::refreshSizesAndStrides() {
  auto dim = value_.dim();
  auto sizes = value_.sizes();
  auto strides = value_.strides();
  sizes_and_strides_.resize(value_.dim());
  for (int64_t i = 0; i < dim; i++) {
    sizes_and_strides_.size_at_unchecked(i) = sizes[i];
    sizes_and_strides_.stride_at_unchecked(i) = strides[i];
  }

  refresh_numel();
  refresh_contiguous();
}

void dumpTensorCout(const Tensor& tensor) {
  dumpTensor(std::cout, tensor);
  std::cout << std::endl;
}

c10::intrusive_ptr<TensorWrapper> makeTensorWrapperPtr(const Tensor& tensor, int64_t level, bool should_be_alive) {
  // TODO: denylist non-cuda/cpu backends to avoid funny business
  DispatchKeySet key_set;
  if (tensor.is_cuda()) {
    key_set = key_set.add(DispatchKey::CUDA);
    key_set = key_set.add(DispatchKey::AutogradCUDA);
  } else {
    key_set = key_set.add(DispatchKey::CPU);
    key_set = key_set.add(DispatchKey::AutogradCPU);
  }
  key_set = key_set.add(DispatchKey::TensorWrapper);
  if (should_be_alive) {
    return c10::make_intrusive<TensorWrapper>(key_set, tensor, level, getLifeHandleForLevel(level));
  } else {
    return c10::make_intrusive<TensorWrapper>(key_set, tensor, level, std::make_shared<bool>(false));
  }
}

Tensor makeTensorWrapper(const Tensor& tensor, int64_t level) {
  auto wrapped = maybeGetTensorWrapper(tensor);
  if (wrapped) {
    TORCH_INTERNAL_ASSERT(wrapped->level() < level);
  }

  // TODO: denylist non-cuda/cpu backends to avoid funny business
  DispatchKeySet key_set;
  if (tensor.is_cuda()) {
    key_set = key_set.add(DispatchKey::CUDA);
    key_set = key_set.add(DispatchKey::AutogradCUDA);
  } else {
    key_set = key_set.add(DispatchKey::CPU);
    key_set = key_set.add(DispatchKey::AutogradCPU);
  }
  key_set = key_set.add(DispatchKey::TensorWrapper);
  auto life_handle = getLifeHandleForLevel(level);
  auto result = at::detail::make_tensor<TensorWrapper>(key_set, tensor, level, std::move(life_handle));
  TORCH_INTERNAL_ASSERT(result.key_set().has(DispatchKey::TensorWrapper));
  return result;
}

bool TensorWrapper::is_alive() const {
  return *is_alive_;
}

c10::intrusive_ptr<TensorImpl> TensorWrapper::shallow_copy_and_detach(
    const c10::VariableVersion& version_counter,
    bool allow_tensor_metadata_change) const {
  auto dest_impl = makeTensorWrapperPtr(value(), level_, is_alive());
  dest_impl->set_version_counter(version_counter);

  // TODO: is this even right?
  dest_impl->set_allow_tensor_metadata_change(allow_tensor_metadata_change);
  return dest_impl;
}

c10::intrusive_ptr<TensorImpl> TensorWrapper::shallow_copy_and_detach(
    c10::VariableVersion&& version_counter,
    bool allow_tensor_metadata_change) const {
  auto dest_impl = makeTensorWrapperPtr(value(), level_, is_alive());
  dest_impl->set_version_counter(version_counter);

  // TODO: is this even right?
  dest_impl->set_allow_tensor_metadata_change(allow_tensor_metadata_change);
  return dest_impl;
}

void TensorWrapper::shallow_copy_from(const c10::intrusive_ptr<TensorImpl>& impl) {
  TORCH_INTERNAL_ASSERT(false, "NYI");
}

TensorWrapper::TensorWrapper(
    c10::DispatchKeySet key_set,
    Tensor value,
    int64_t level,
    std::shared_ptr<bool> is_alive,
    bool use_value_sizes_strides)
  : TensorImpl(key_set, value.dtype(), value.device())
  , value_(std::move(value))
  , level_(level)
  , is_alive_(std::move(is_alive))
{
  TORCH_INTERNAL_ASSERT(value_.defined());
  set_storage_access_should_throw();

  // TODO: need to reset sizes/strides on mutation
  TORCH_INTERNAL_ASSERT(use_value_sizes_strides);
  refreshSizesAndStrides();
}

// The following are some internal inherited methods that we do not support.
// They should never get called.
void TensorWrapper::set_size(int64_t dim, int64_t new_size) {
  TORCH_INTERNAL_ASSERT(false, "Can't set_size for TensorWrapper");
}
void TensorWrapper::set_stride(int64_t dim, int64_t new_stride) {
  TORCH_INTERNAL_ASSERT(false, "Can't set_stride for TensorWrapper");
}
void TensorWrapper::set_storage_offset(int64_t storage_offset) {
  TORCH_INTERNAL_ASSERT(false, "Can't set_storage_offset for TensorWrapper");
}

const char* TensorWrapper::tensorimpl_type_name() const {
  return "TensorWrapper";
}

TensorWrapper* maybeGetTensorWrapper(const Tensor& tensor) {
  if (!tensor.key_set().has(DispatchKey::TensorWrapper)) {
    return nullptr;
  }
  return static_cast<TensorWrapper*>(tensor.unsafeGetTensorImpl());
}

static void foreachTensorInplace(std::vector<IValue>& args, int64_t begin, int64_t end,
    std::function<Tensor(const Tensor&)> func) {
  TORCH_INTERNAL_ASSERT(begin >= 0);
  TORCH_INTERNAL_ASSERT(end >= 0);
  TORCH_INTERNAL_ASSERT(begin <= end);
  for (int64_t idx = begin; idx < end; idx++) {
    auto ivalue = args[idx];
    if (ivalue.isTensorList()) {
      TORCH_INTERNAL_ASSERT(false, "NYI: TensorList");
    }
    if (!ivalue.isTensor()) {
      continue;
    }
    Tensor value = ivalue.toTensor();
    Tensor replacement = func(value);
    args[idx] = replacement; // TODO: std::move?
    if (ivalue.toTensor().defined()) {
      TORCH_INTERNAL_ASSERT(args[idx].toTensor().defined());
    }
  }
}

static Tensor unwrapIfDead(const Tensor& tensor) {
  auto* wrapped = maybeGetTensorWrapper(tensor);
  if (!wrapped) {
    return tensor;
  }
  if (wrapped->is_alive()) {
    return tensor;
  }
  return wrapped->value();
}

void dead_tensor_wrapper_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
  auto args_size = op.schema().arguments().size();
  int64_t unwrapped_count = 0;
  auto unwrapIfDeadAndIncrement = [&](const Tensor& tensor) {
    auto* wrapped = maybeGetTensorWrapper(tensor);
    if (!wrapped) {
      return tensor;
    }
    if (wrapped->is_alive()) {
      return tensor;
    }
    unwrapped_count++;
    return wrapped->value();
  };

  foreachTensorInplace(*stack, stack->size() - args_size, stack->size(), unwrapIfDeadAndIncrement);
  TORCH_INTERNAL_ASSERT(unwrapped_count > 0, "Should have at least one dead wrapper");

  // re-dispatch
  op.callBoxed(stack);
}

// TensorWrapper backend fallback: Unwrap and fallthrough.

TORCH_LIBRARY_IMPL(_, TensorWrapper, m) {
  m.fallback(torch::CppFunction::makeFromBoxedFunction<&dead_tensor_wrapper_fallback>());
}

} // namespace functorch
} // namespace at
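For reference, a sketch of dumpTensor's recursive output format. This is not a standalone program: makeTensorWrapper asserts its level is registered via initAndPushDynamicLayer, so the levels below are assumed to be alive:

#include <functorch/csrc/TensorWrapper.h>
#include <functorch/csrc/DynamicLayer.h>
#include <ATen/ATen.h>

void dump_demo() {
  // Assume levels 2 and 3 were created by two nested initAndPushDynamicLayer calls.
  auto t  = at::ones({3});
  auto w2 = at::functorch::makeTensorWrapper(t,  /*level=*/2);
  auto w3 = at::functorch::makeTensorWrapper(w2, /*level=*/3);
  // Prints: Wrapper[3, Wrapper[2, Tensor[3]]]
  at::functorch::dumpTensorCout(w3);
}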
61
functorch/functorch/csrc/TensorWrapper.h
Normal file
@@ -0,0 +1,61 @@
#pragma once

#include <ATen/Tensor.h>

namespace at {
namespace functorch {

struct TORCH_API TensorWrapper : public c10::TensorImpl {
  explicit TensorWrapper(
      c10::DispatchKeySet key_set,
      Tensor value,
      int64_t level,
      std::shared_ptr<bool> is_alive,
      bool use_value_sizes_strides = true);

  // Override a bunch of methods inherited from TensorImpl to return error messages
  void set_size(int64_t dim, int64_t new_size) override;
  void set_stride(int64_t dim, int64_t new_stride) override;
  void set_storage_offset(int64_t storage_offset) override;

  void refreshSizesAndStrides();

  const Tensor& value() const {
    return value_;
  }
  optional<int64_t> level() const {
    if (is_alive()) {
      return level_;
    }
    return {};
  }
  bool is_alive() const;

  // Overrides necessary for autograd
  c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
      const c10::VariableVersion& version_counter,
      bool allow_tensor_metadata_change) const override;
  c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
      c10::VariableVersion&& version_counter,
      bool allow_tensor_metadata_change) const override;
  void shallow_copy_from(const c10::intrusive_ptr<TensorImpl>& impl) override;

 private:
  const char* tensorimpl_type_name() const override;
  Tensor value_;
  int64_t level_;

  // When we exit the level, this wrapper may be marked as "not alive".
  // Wrappers that are not alive:
  // 1) May still have autograd metadata on them
  // 2) Forward dispatches to the underlying value()
  std::shared_ptr<bool> is_alive_;
};

TORCH_API Tensor makeTensorWrapper(const Tensor& tensor, int64_t level);
TORCH_API TensorWrapper* maybeGetTensorWrapper(const Tensor& tensor);
TORCH_API void dumpTensor(std::ostream& ss, const Tensor& tensor);
TORCH_API void dumpTensorCout(const Tensor& tensor);

} // namespace functorch
} // namespace at
60
functorch/functorch/csrc/VmapMode.cpp
Normal file
@@ -0,0 +1,60 @@
#include <functorch/csrc/DynamicLayer.h>
#include <functorch/csrc/VmapMode.h>
#include <functorch/csrc/Constants.h>

namespace at {
namespace functorch {
namespace impl {

/// thread_local is a feature that is not enabled by Caffe2 mobile
/// build (e.g. iOS). Therefore, we only provide `at::VmapMode`
/// when we are not in mobile build or when FEATURE_TORCH_MOBILE
/// is on.
#if !defined(C10_MOBILE) || defined(FEATURE_TORCH_MOBILE)

thread_local int64_t VmapMode_current_vmap_level = 0;

int64_t VmapMode::current_vmap_level() {
  return VmapMode_current_vmap_level;
}

int64_t VmapMode::increment_nesting() {
  VmapMode_current_vmap_level++;

  auto level = initAndPushDynamicLayer(kBatchedKey);
  if (VmapMode_current_vmap_level == 1) {
    c10::impl::tls_set_dispatch_key_included(DispatchKey::VmapMode, true);
  }
  return level;
}

int64_t VmapMode::decrement_nesting() {
  VmapMode_current_vmap_level--;
  auto layer = popDynamicLayerAndDeleteMetadata();
  TORCH_INTERNAL_ASSERT(layer.key() == kBatchedKey);
  if (VmapMode_current_vmap_level == 0) {
    c10::impl::tls_set_dispatch_key_included(DispatchKey::VmapMode, false);
  }
  // TODO: this return value should never be used
  return VmapMode_current_vmap_level;
}

#else

// NB: must match the declaration in VmapMode.h (current_vmap_level).
int64_t VmapMode::current_vmap_level() {
  TORCH_CHECK(false, "VmapMode is not supported on mobile");
}

int64_t VmapMode::increment_nesting() {
  TORCH_CHECK(false, "VmapMode is not supported on mobile");
}

int64_t VmapMode::decrement_nesting() {
  TORCH_CHECK(false, "VmapMode is not supported on mobile");
}

#endif

} // namespace impl
} // namespace functorch
} // namespace at
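A sketch of the increment/decrement pairing that torch.vmap performs around each nested call. The concrete level numbers are illustrative; they depend on the current dynamic layer stack depth, and each decrement must mirror an increment:

#include <functorch/csrc/VmapMode.h>

void vmap_nesting_demo() {
  // Each increment pushes a batched dynamic layer and returns its level.
  auto outer = at::functorch::impl::VmapMode::increment_nesting();  // e.g. level 2
  auto inner = at::functorch::impl::VmapMode::increment_nesting();  // e.g. level 3
  (void)outer; (void)inner;
  // ... run the vmapped function here; BatchedTensors carry these levels ...
  at::functorch::impl::VmapMode::decrement_nesting();  // pops `inner`
  at::functorch::impl::VmapMode::decrement_nesting();  // pops `outer`
}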
30
functorch/functorch/csrc/VmapMode.h
Normal file
@@ -0,0 +1,30 @@
#pragma once

#include <c10/core/impl/LocalDispatchKeySet.h>

namespace at {
namespace functorch {
namespace impl {

// VmapMode contains a thread local count of how many nested vmaps
// we are currently inside. That number is known as the `vmap level`.
// VmapMode is used in the implementation of the Python `torch.vmap` API.
//
// NOTE: this is NOT the C++ API for torch.vmap. That doesn't exist yet.

struct TORCH_API VmapMode {
  // Returns the vmap level, aka the count of how many nested vmaps we're in.
  static int64_t current_vmap_level();

  // Increments the count of nested vmaps. If this causes the vmap level to be
  // greater than 0, then it enables DispatchKey::VmapMode on all tensors.
  static int64_t increment_nesting();

  // Decrements the count of nested vmaps. If this causes the vmap level to be
  // equal to 0, then it disables DispatchKey::VmapMode on all tensors.
  static int64_t decrement_nesting();
};

} // namespace impl
} // namespace functorch
} // namespace at
323
functorch/functorch/csrc/VmapTransforms.cpp
Normal file
@@ -0,0 +1,323 @@
#include <functorch/csrc/VmapTransforms.h>
#include <functorch/csrc/DynamicLayer.h>

#include <ATen/ATen.h>

namespace at {
namespace functorch {

// Checks if the batch dims in `bdims` appear at the front of the tensor.
static bool areBdimsAtFrontInOrder(BatchDimsRef bdims) {
  for (int64_t idx = 0; idx < bdims.size(); idx++) {
    if (bdims[idx].dim() != idx) {
      return false;
    }
  }
  return true;
}

// Takes a BatchedTensorImpl, permutes all of the batch dims to the front,
// and then returns a physical version of the Tensor.
static Tensor permuteBatchDimsToFront(BatchedTensorImpl* batched) {
  auto bdims = batched->bdims();
  const Tensor& physical_tensor = batched->value();
  if (areBdimsAtFrontInOrder(bdims)) {
    return physical_tensor;
  }
  const auto sizes = physical_tensor.sizes();
  VmapDimVector permutation(sizes.size(), 0);
  permutation.reserve(sizes.size());
  const auto is_bdim = createBatchDimBitset(bdims);
  int64_t idx = 0;
  for (const auto& bdim : bdims) {
    permutation[idx++] = bdim.dim();
  }
  for (int64_t ptr = 0; idx < sizes.size(); ptr++) {
    if (is_bdim[ptr]) {
      continue;
    }
    permutation[idx++] = ptr;
  }
  return physical_tensor.permute(permutation);
}

VmapPhysicalView MultiBatchVmapTransform::logicalToPhysical(const Tensor& logical_tensor) {
  auto* batched = maybeGetBatchedImpl(logical_tensor);
  TORCH_INTERNAL_ASSERT(
      batched,
      "logicalToPhysical(tensor) should only be passed a BatchedTensor");
  return { permuteBatchDimsToFront(batched), createVmapLevelsBitset(batched->bdims()) };
}

int64_t VmapPhysicalView::numBatchDims() const {
  return levels_.count();
}

int64_t VmapPhysicalView::numLogicalDims() const {
  return /*physical*/tensor_.dim() - numBatchDims();
}

VmapDimVector VmapPhysicalView::getPhysicalDims(IntArrayRef logical_dims) const {
  auto logical_ndim = numLogicalDims();
  // NB: fmap doesn't have a SmallVector variant, so we don't use it here.
  VmapDimVector result;
  result.reserve(logical_ndim);
  for (auto dim : logical_dims) {
    result.push_back(maybe_wrap_dim(dim, logical_ndim) + numBatchDims());
  }
  return result;
}

int64_t VmapPhysicalView::getPhysicalDim(int64_t logical_dim) const {
  auto logical_ndim = numLogicalDims();
  return maybe_wrap_dim(logical_dim, logical_ndim) + numBatchDims();
}

VmapDimVector VmapPhysicalView::getPhysicalShape(IntArrayRef logical_shape) const {
  VmapDimVector result;
  result.reserve(logical_shape.size() + numBatchDims());
  auto tensor_sizes = tensor_.sizes();
  result.insert(result.end(), tensor_sizes.begin(), tensor_sizes.begin() + numBatchDims());
  result.insert(result.end(), logical_shape.begin(), logical_shape.end());
  return result;
}

static BatchDims computeFrontBatchDimsFromLevels(std::bitset<kVmapNumLevels> levels_bitset) {
  BatchDims bdims;
  int64_t dim = 0;
  for (int64_t level = 0; level < kVmapNumLevels; level++) {
    if (!levels_bitset[level]) {
      continue;
    }
    bdims.emplace_back(level, dim++);
  }
  return bdims;
}

// Given a Tensor or a BatchedTensor, returns the underlying physical tensor
// with all vmapped dimensions permuted to the front, if they exist, and a
// bitset of vmap levels that were present in the tensor.
static std::pair<Tensor,std::bitset<kVmapNumLevels>>
getPhysicalTensorAndLevels(const Tensor& self) {
  auto* batched = maybeGetBatchedImpl(self);
  if (batched) {
    return {permuteBatchDimsToFront(batched), createVmapLevelsBitset(batched->bdims())};
  }
  return {self, 0};
}

// Given a Tensor or a BatchedTensor, creates a physical view of the tensor
// such that it has a batch dimension for each level in `requested_levels`
// and `requested_example_dim` number of non-batch-dimensions.
//
// This function is useful in preparing physical views on tensors that can
// then be passed into broadcasting operations. For example, when adding
// two BatchedTensors of sizes [B0, 3] and [B0, B1, 2, 3], where the Bi are the
// batch dimensions, we must align the batch dimensions and non-batch-dimensions
// (henceforth referred to as the "example" dimensions) separately to produce
// tensors of size [B0, 1, 1, 3] and [B0, B1, 2, 3] so that they can be added.
//
// Here's a direct example of using alignBatchDimsAtFront on the above two tensors.
//
// 1) alignBatchDimsAtFront([B0, 3], requested_levels={0, 1}, requested_example_dim=2)
// returns a physical view of size [B0, 1, 1, 3] by adding an extra dimension for
// level 1 and another extra dimension to pad the example dimensions to 2.
//
// 2) alignBatchDimsAtFront([B0, B1, 2, 3], requested_levels={0, 1}, requested_example_dim=2)
// returns a physical view of size [B0, B1, 2, 3]
static Tensor alignBatchDimsAtFront(
    const Tensor& self,
    std::bitset<kVmapNumLevels> requested_levels,
    int64_t requested_example_dim) {
  Tensor physical_tensor;
  std::bitset<kVmapNumLevels> tensor_levels;
  std::tie(physical_tensor, tensor_levels) = getPhysicalTensorAndLevels(self);

  TORCH_INTERNAL_ASSERT(
      (tensor_levels | requested_levels) == requested_levels,
      "`requested_levels` must be a superset of `self`'s levels");

  auto physical_sizes = physical_tensor.sizes();

  auto tensor_example_dim = physical_sizes.size() - /*num_batch_dims*/tensor_levels.count();
  TORCH_INTERNAL_ASSERT(tensor_example_dim <= requested_example_dim);

  if (tensor_levels == requested_levels && tensor_example_dim == requested_example_dim) {
    // Optimization: no need to do another view if the physical tensor is
    // already the correct shape
    return physical_tensor;
  }

  VmapDimVector aligned_sizes(requested_levels.count() + requested_example_dim, 1);

  // align the example dims (non-bdims dims) first
  // aligned_sizes[-tensor_example_dim:] = tensor_sizes[-tensor_example_dim:]
  std::copy(
      physical_sizes.rbegin(),
      physical_sizes.rbegin() + tensor_example_dim,
      aligned_sizes.rbegin());

  // align the bdims
  int64_t level = 0;
  int64_t tensor_dim = 0;
  for (int64_t bdim = 0; bdim < requested_levels.count(); bdim++) {
    // Determine the level of the bdim
    while (!requested_levels[level]) level++;
    if (tensor_levels[level]) {
      aligned_sizes[bdim] = physical_sizes[tensor_dim++];
    }
    level++;
  }
  return physical_tensor.view(aligned_sizes);
}

static Tensor moveDimToFrontAndExpand(Tensor tensor, optional<int64_t> dim, int64_t size) {
  if (dim) {
    tensor = tensor.movedim(*dim, 0);
  } else {
    tensor = tensor.unsqueeze(0);
    auto expanded_sizes = tensor.sizes().vec();
    expanded_sizes[0] = size;
    tensor = tensor.expand(expanded_sizes);
  }
  return tensor;
}

// The algorithm is as follows:
// 1. Figure out what all of the collective levels in `logical_tensors` are.
// 2. Move all batch dims to the front of the tensors and add extra dims
//    of size 1. At this point, every tensor will have a dimension for
//    each of the collective levels.
// 3. Compute the batch_sizes.
// 4. Expand each physical tensor so that they have output batch size equal
//    to `batch_sizes`.
VmapPhysicalViewVec
MultiBatchVmapTransform::logicalToPhysical(TensorList logical_tensors) {
  auto cur_level = maybeCurrentDynamicLayer().value().layerId();
  auto bdim_size = -1;

  // Figure out the batch size first
  for (const auto& logical_tensor : logical_tensors) {
    auto* batched = maybeGetBatchedImpl(logical_tensor);
    if (!batched) {
      continue;
    }
    TORCH_INTERNAL_ASSERT(batched->bdims().size() == 1);
    if (batched->bdims().back().level() != cur_level) {
      continue;
    }
    bdim_size = batched->value().size(batched->bdims().back().dim());
  }
  TORCH_INTERNAL_ASSERT(bdim_size != -1);

  std::bitset<kVmapNumLevels> levels;
  levels[cur_level] = 1;

  VmapPhysicalViewVec result;
  for (const auto& logical_tensor : logical_tensors) {
    auto* batched = maybeGetBatchedImpl(logical_tensor);
    if (!batched || (batched->bdims().back().level() != cur_level)) {
      // Unsqueeze dim 0, expand it to the correct shape
      c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
      auto value = moveDimToFrontAndExpand(logical_tensor, {}, bdim_size);
      result.emplace_back(std::move(value), levels);
      continue;
    }
    c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
    auto physical = batched->value();
    auto value = moveDimToFrontAndExpand(physical, batched->bdims().back().dim(), bdim_size);
    result.emplace_back(std::move(value), levels);
  }

  return result;
}

static std::pair<std::bitset<kVmapNumLevels>,int64_t>
getLevelsAndLargestLogicalDim(TensorList logical_tensors) {
  TORCH_INTERNAL_ASSERT(logical_tensors.size() > 0);
  std::bitset<kVmapNumLevels> levels;
  int64_t largest_logical_dim = -1;
  for (const auto& tensor : logical_tensors) {
    auto* batched = maybeGetBatchedImpl(tensor);
    if (batched) {
      levels = levels | createVmapLevelsBitset(batched->bdims());
    }
    auto tensor_logical_dim = /*logical dim*/tensor.dim();
    if (tensor_logical_dim > largest_logical_dim) {
      largest_logical_dim = tensor_logical_dim;
    }
  }
  return { levels, largest_logical_dim };
}

static Tensor moveDimToFrontAndUnsqueeze(Tensor tensor, optional<int64_t> dim, int64_t example_ndim) {
  if (dim) {
    tensor = tensor.movedim(*dim, 0);
  } else {
    tensor = tensor.unsqueeze(0);
  }
  auto ndim = tensor.dim() - 1;
  for (int64_t i = 0; i < example_ndim - ndim; i++) {
    tensor = tensor.unsqueeze(1);
  }
  return tensor;
}

VmapPhysicalViewVec BroadcastingVmapTransform::logicalToPhysical(TensorList logical_tensors) {
  auto cur_level = maybeCurrentDynamicLayer().value().layerId();
  auto bdim_size = -1;

  // Figure out the batch size first
  for (const auto& logical_tensor : logical_tensors) {
    auto* batched = maybeGetBatchedImpl(logical_tensor);
    if (!batched || (batched->bdims().back().level() != cur_level)) {
      continue;
    }
    bdim_size = batched->value().size(batched->bdims().back().dim());
  }
  TORCH_INTERNAL_ASSERT(bdim_size != -1);

  std::bitset<kVmapNumLevels> levels;
  levels[cur_level] = 1;

  // figure out the example ndim
  int64_t max_example_dim = -1;
  for (const auto& logical_tensor : logical_tensors) {
    max_example_dim = std::max(logical_tensor.dim(), max_example_dim);
  }

  VmapPhysicalViewVec result;
  for (const auto& logical_tensor : logical_tensors) {
    auto* batched = maybeGetBatchedImpl(logical_tensor);
    if (!batched || (batched->bdims().back().level() != cur_level)) {
      // Unsqueeze dim 0, expand it to the correct shape
      c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
      auto value = moveDimToFrontAndUnsqueeze(logical_tensor, {}, max_example_dim);
      result.emplace_back(std::move(value), levels);
      continue;
    }
    c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
    auto physical = batched->value();
    auto value = moveDimToFrontAndUnsqueeze(physical, batched->bdims().back().dim(), max_example_dim);
    result.emplace_back(std::move(value), levels);
  }

  return result;
}

VmapPhysicalToLogicalMap VmapPhysicalView::getPhysicalToLogicalMap() const {
  return VmapPhysicalToLogicalMap(levels_);
}

Tensor VmapPhysicalToLogicalMap::apply(const Tensor& physical_tensor) const {
  return makeBatched(physical_tensor, computeFrontBatchDimsFromLevels(levels_));
}

void VmapPhysicalToLogicalMap::applyInplace(std::vector<Tensor>& physical_tensors) const {
  for (int64_t idx = 0; idx < physical_tensors.size(); ++idx) {
    physical_tensors[idx] = apply(physical_tensors[idx]);
  }
}

} // namespace functorch
} // namespace at
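A standalone ATen sketch of what moveDimToFrontAndExpand does inside the TensorList transform above, for one tensor whose batch dim sits at physical dim 1 and one tensor with no batch dim (bdim_size = 4 is illustrative):

#include <ATen/ATen.h>
#include <iostream>

int main() {
  int64_t bdim_size = 4;

  // Batched input: logical size [3], physical size [3, 4] with bdim at dim 1.
  auto batched = at::ones({3, 4}).movedim(1, 0);          // -> [4, 3]

  // Unbatched input: gains a size-1 front dim, then expands to the batch size.
  auto plain = at::ones({3}).unsqueeze(0);                // -> [1, 3]
  auto expanded_sizes = plain.sizes().vec();
  expanded_sizes[0] = bdim_size;
  plain = plain.expand(expanded_sizes);                   // -> [4, 3]

  // Both now agree on the front batch dim and can be used in one physical op.
  std::cout << batched.sizes() << " " << plain.sizes() << "\n";  // [4, 3] [4, 3]
  return 0;
}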
177
functorch/functorch/csrc/VmapTransforms.h
Normal file
177
functorch/functorch/csrc/VmapTransforms.h
Normal file
@ -0,0 +1,177 @@
|
||||
#pragma once
|
||||
|
||||
#include <functorch/csrc/BatchedTensorImpl.h>
|
||||
|
||||
namespace at {
|
||||
namespace functorch {
|
||||
|
||||
// This file contains abstractions used for transforming *logical* vmap arguments
|
||||
// into *physical* arguments. (Keep reading for definitions of these terms).
|
||||
|
||||
// NOTE: [Logical vs physical args]
|
||||
// Consider the following vmap.
|
||||
// vmap(vmap(func, in_dims=(2,)), in_dims=(0,))(torch.ones(2, 3, 4))
|
||||
// This would produce a BatchedTensor wrapping a Tensor of size [2, 3, 4],
|
||||
// with batch dims 0 and 2:
|
||||
// BatchedTensor(ones(2, 3, 4), bdims=[(lvl=1,dim=0),(lvl=2,dim=2)])
|
||||
//
|
||||
// We say the *logical* view of the tensor has size [3] -- tensors inside
|
||||
// `func` appear to have size [3].
|
||||
// However, the *physical* underlying tensor (the one passed to vmap) has size
|
||||
// [2, 3, 4].
|
||||
//
|
||||
// This notion of logical vs physical also extends to non-tensor arguments.
|
||||
// Consider the previous tensor; let's assume the user called
|
||||
// `torch.sum(tensor, dim=0)` inside of `func`. Then the logical
|
||||
// dimension they are reducing over is dim 0 but the physical dim is dim 1
|
||||
// (the first non-batch dimension)
|

// Forward declared; see NOTE: [What is a VmapPhysicalView?]
struct VmapPhysicalView;

// Most PyTorch operators take 4 or fewer inputs.
constexpr int64_t kVmapTransformStaticInputSize = 4;
using VmapPhysicalViewVec = SmallVector<VmapPhysicalView, kVmapTransformStaticInputSize>;

// PyTorch generally advertises good performance for <= 5 dims
// (see ATen/core/DimVector.h). We add a few extra dims (~3) for vmap
// dimensions to get 8. Adjust this number as necessary.
constexpr int64_t kVmapStaticDimVecSize = 8;
using VmapDimVector = SmallVector<int64_t, kVmapStaticDimVecSize>;

// NOTE: [What is a VmapTransform?]
// A *VmapTransform* converts logical views of tensors to physical views.
//
// Batching rules use VmapTransforms to convert logical arguments to
// physical arguments, then call one or more at:: operators that handle the
// physical arguments, and then convert the physical result back to a
// logical argument.

// VmapTransform for operators that take tensors with multiple batch dims.
// Given one or more logical views on Tensors, `logicalToPhysical`
// permutes all of the batch dims to the front of the tensor, aligns
// and expands the batch dims to match each other (according to their `level`),
// and returns a VmapPhysicalView on the tensor(s).
struct TORCH_API MultiBatchVmapTransform {
  static VmapPhysicalView logicalToPhysical(const Tensor& logical_tensor);
  static VmapPhysicalViewVec logicalToPhysical(TensorList logical_tensors);
};

// VmapTransform for operators that broadcast all inputs.
// Given some logical views on Tensors, `logicalToPhysical`:
// - permutes all of the batch dims to the front of the tensors
// - aligns all the batch dims to the collective levels of all of the tensors.
//   If a tensor does not have a batch dim for a vmap level, then it receives
//   a size-one dimension for said level.
// - aligns the non-batch dims to have the same dimensionality, adding extra
//   size-1 dimensions in between the batch dimensions and the non-batch
//   dimensions so that the batch dimensions are lined up from the right.
//
// For example: given inputs of size (B, 2) and (B, 3, 2) where B is the batch
// dimension, BroadcastingVmapTransform returns VmapPhysicalViews that wrap
// tensors of size (B, 1, 2) and (B, 3, 2).
//
// Given inputs of size (B, 2) and (2,), BroadcastingVmapTransform returns
// VmapPhysicalViews wrapping tensors of size (B, 2) and (1, 2). We don't
// actually *need* to return a tensor of size (1, 2) for the second tensor
// because the broadcasting operation takes care of that for us, but we do
// it anyway to keep things simple.
struct TORCH_API BroadcastingVmapTransform {
  static VmapPhysicalViewVec logicalToPhysical(TensorList logical_tensors);
};

// Forward declared; if you're reading this file head to toe, don't worry
// about it yet.
struct VmapPhysicalToLogicalMap;

// NOTE: [What is a VmapPhysicalView?]
// VmapPhysicalView represents a physical view on a Tensor.
//
// One can use it to further convert logical dimension indices, logical shapes,
// and more to their physical variants, or convert a new (physical) tensor into
// a logical BatchedTensor. (TODO(rzou): some of these are not yet implemented).
//
// VmapPhysicalView stores a physical tensor with all of its batch dimensions at
// the front and some levels that correspond to said batch dimensions.
//
// The levels bitset specifies which vmap levels correspond to the batch
// dimensions at the front of the tensor. In particular, the number of set bits
// corresponds to the number of batch dimensions on `tensor` and the rightmost
// bit of `levels` specifies the maximum number of nested vmaps we are in at
// this point in time.
// For example, given:
//   physical_view = VmapPhysicalView(tensor=ones(2, 3, 4, 5, 6), levels={1, 3})
//
// The rightmost set bit of `levels` is 3, indicating that the number of
// nested vmaps is less than or equal to 3.
//   bitset: 010100
//              ^
//              |
//   levels: 012345
struct TORCH_API VmapPhysicalView {
  VmapPhysicalView(Tensor&& tensor, std::bitset<kVmapNumLevels> levels)
      : levels_(levels), tensor_(std::move(tensor)) {
    // TORCH_INTERNAL_ASSERT(!isBatchedTensor(tensor));
  }

  Tensor& tensor() { return tensor_; }
  const Tensor& tensor() const { return tensor_; }

  // Maps logical dim indices to physical dim indices. Also does dim wrapping.
  //
  // For example, given:
  //   physical_view = VmapPhysicalView(tensor=ones(2, 3, 4, 5), levels={1, 3})
  //
  // Then physical_view.getPhysicalDims({0, 1}) returns {2, 3}.
  // This is because the size of `levels` tells us that the first two
  // dimensions of `tensor_` are batch dimensions, so a logical dim of `n` is
  // actually a physical dim of `n + 2`.
  VmapDimVector getPhysicalDims(IntArrayRef logical_dims) const;
  int64_t getPhysicalDim(int64_t logical_dim) const;

  // Returns a VmapPhysicalToLogicalMap object. This can be used for
  // mapping a physical tensor to a new logical tensor (BatchedTensor).
  VmapPhysicalToLogicalMap getPhysicalToLogicalMap() const;

  // Maps a logical shape to a physical shape by pre-pending the batch
  // sizes to the logical shape.
  VmapDimVector getPhysicalShape(IntArrayRef logical_shape) const;

  int64_t numBatchDims() const;

 private:
  int64_t numLogicalDims() const;

  std::bitset<kVmapNumLevels> levels_;
  Tensor tensor_;
};
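The dim arithmetic described in the getPhysicalDims comment is simple enough to state as a sketch (plain Python; the helper name is hypothetical):

def get_physical_dim(logical_dim, num_logical_dims, num_batch_dims):
    # Dim wrapping: -1 refers to the last logical dim, and so on.
    if logical_dim < 0:
        logical_dim += num_logical_dims
    assert 0 <= logical_dim < num_logical_dims
    return logical_dim + num_batch_dims

# ones(2, 3, 4, 5) with levels={1, 3}: two batch dims at the front, so
# logical dims {0, 1} map to physical dims {2, 3}.
assert get_physical_dim(0, 2, 2) == 2
assert get_physical_dim(1, 2, 2) == 3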

// Convenience struct used for mapping a physical tensor (a non-BatchedTensor)
// to a logical one (BatchedTensor). It holds some levels that are used to do
// the mapping and assumes that the batch dimensions in the physical tensor
// all occur at the front of the tensor.
struct TORCH_API VmapPhysicalToLogicalMap {
  VmapPhysicalToLogicalMap(std::bitset<kVmapNumLevels> levels): levels_(levels) {}

  // Maps a physical tensor to a new logical tensor (BatchedTensor).
  // Assumes that all of the "batch dimensions" are at the front
  // of the physical tensor. For example, given:
  // - x = rank-4 Tensor with size 2, 3, 5, 7
  // - levels = (2, 4)
  // Returns:
  // - BatchedTensor(x, bdims=[(dim=0,lvl=2), (dim=1,lvl=4)])
  Tensor apply(const Tensor& physical_tensor) const;

  // Given a vector of physical tensors,
  // 1. maps each tensor to a new logical tensor. Assumes that all of the
  //    "batch dimensions" are at the front of the physical tensors.
  // 2. stores the new logical tensors back into the passed-in vector. This is
  //    to avoid additional dynamic allocations.
  void applyInplace(std::vector<Tensor>& physical_tensors) const;

  std::bitset<kVmapNumLevels> levels_;
};

} // namespace functorch
} // namespace at
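Given the apply() example above, the levels-to-bdims bookkeeping reduces to enumerating set bits in order. A sketch of the assumed semantics of computeFrontBatchDimsFromLevels:

def front_bdims_from_levels(levels):
    # Each set vmap level becomes one batch dim at the front, in order.
    return [(dim, lvl) for dim, lvl in enumerate(sorted(levels))]

# levels = {2, 4} on a rank-4 physical tensor
# -> bdims [(dim=0, lvl=2), (dim=1, lvl=4)], matching the comment above.
assert front_bdims_from_levels({2, 4}) == [(0, 2), (1, 4)]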
176
functorch/functorch/csrc/init.cpp
Normal file
@ -0,0 +1,176 @@
#include <torch/extension.h>
#include <ATen/WrapDimUtils.h>

#include <functorch/csrc/TensorWrapper.h>
#include <functorch/csrc/DynamicLayer.h>
#include <functorch/csrc/BatchedTensorImpl.h>
#include <functorch/csrc/VmapTransforms.h>
#include <functorch/csrc/VmapMode.h>

namespace at {
namespace functorch {

static bool has_level(const Tensor& self, int64_t level) {
  const auto* batched = maybeGetBatchedImpl(self);
  if (!batched) {
    return false;
  }
  auto bdims = batched->bdims();
  auto* it = std::find_if(bdims.begin(), bdims.end(), [&](const BatchDim& bdim) {
    return bdim.level() == level;
  });
  return it != bdims.end();
}

Tensor _add_batch_dim(const Tensor& self, int64_t batch_dim, int64_t level) {
  return addBatchDim(self, level, batch_dim);
}

static std::pair<Tensor,int64_t> remove_existing_batch_dim(
    const BatchedTensorImpl* batched, int64_t level) {
  auto bdims = batched->bdims();
  if (bdims.size() == 1) {
    TORCH_INTERNAL_ASSERT(bdims[0].level() == level);
    return std::make_pair(batched->value(), bdims[0].dim());
  }
  BatchDims new_bdims;
  int64_t newly_exposed_physical_dim = -1;
  new_bdims.reserve(bdims.size() - 1);
  for (const auto& bdim : bdims) {
    if (bdim.level() == level) {
      newly_exposed_physical_dim = bdim.dim();
    } else {
      new_bdims.push_back(bdim);
    }
  }
  // Because a BatchDim with level `level` must exist inside `batched`,
  // we should have found a `newly_exposed_physical_dim`.
  TORCH_INTERNAL_ASSERT(newly_exposed_physical_dim != -1);
  int64_t num_batch_dims_before_newly_exposed_physical_dim = std::count_if(
      new_bdims.begin(), new_bdims.end(),
      [&](const BatchDim& bdim) {
        return bdim.dim() < newly_exposed_physical_dim;
      });
  int64_t newly_exposed_logical_dim =
      newly_exposed_physical_dim - num_batch_dims_before_newly_exposed_physical_dim;
  auto result_tensor = makeBatched(batched->value(), std::move(new_bdims));
  return std::make_pair(std::move(result_tensor), newly_exposed_logical_dim);
}
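The index arithmetic above is worth spelling out: once the batch dim at physical position p is removed, its logical position is p minus the number of remaining batch dims that sit in front of it. A sketch:

def newly_exposed_logical_dim(removed_physical_dim, remaining_batch_dims):
    in_front = sum(1 for d in remaining_batch_dims if d < removed_physical_dim)
    return removed_physical_dim - in_front

# Batch dims at physical positions {0, 2}; removing the one at position 2
# exposes it at logical position 1 (one batch dim remains in front of it).
assert newly_exposed_logical_dim(2, [0]) == 1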

// Poor man's version of np.moveaxis. Moves the dimension at `src` to `dst`
// while preserving the order of other existing dimensions.
// We should probably add np.moveaxis (it is more general) to PyTorch. (#36048)
// When we do, replace the following with it.
static Tensor _movedim(const Tensor& self, int64_t src, int64_t dst) {
  auto logical_dim = self.dim();
  src = maybe_wrap_dim(src, logical_dim);
  dst = maybe_wrap_dim(dst, logical_dim);
  if (src == dst) {
    return self;
  }
  VmapDimVector permutation;
  permutation.reserve(logical_dim);
  for (int64_t dim = 0; dim < logical_dim; dim++) {
    if (dim == src) {
      continue;
    }
    permutation.push_back(dim);
  }
  permutation.insert(permutation.begin() + dst, src);
  return self.permute(permutation);
}
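The permutation built here matches what torch.movedim computes (movedim landed in PyTorch 1.7, after this comment was written); a quick check:

import torch

x = torch.randn(2, 3, 4)
perm = [d for d in range(x.dim()) if d != 0]  # drop src=0 -> [1, 2]
perm.insert(2, 0)                             # insert at dst=2 -> [1, 2, 0]
assert torch.equal(x.permute(perm), torch.movedim(x, 0, 2))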

// Removes the batch dim with level `level` from `self`. If this causes the
// last batch dim to be removed from a BatchedTensor, then this returns a
// regular Tensor.
//
// If the `level` of the batch dim to remove does not exist in `self`, then we
// add the batch dim in. This can happen if `self` didn't interact with a tensor
// inside the vmap level, for example,
//     self = torch.randn(3)
//     y = torch.randn(5)
//     out = vmap(lambda x: vmap(lambda y: x)(y))(self)
//     assert out.shape == (3, 5)
// Inside the inner vmap, `x` is a BatchedTensor with a single batch dimension
// corresponding to the *outer* vmap level and it doesn't have any dimensions
// that correspond to the inner vmap level, so we need to create one for the
// user.
//
// `out_dim` controls where we should put the batch dimension in the output
// tensor.
Tensor _remove_batch_dim(const Tensor& self, int64_t level, int64_t batch_size, int64_t out_dim) {
  if (!has_level(self, level)) {
    auto self_sizes = self.sizes();
    VmapDimVector expanded_sizes(self_sizes.begin(), self_sizes.end());
    expanded_sizes.insert(expanded_sizes.begin() + out_dim, batch_size);
    return self.expand(expanded_sizes);
  }

  // Must be batched if has_level(self, /*any_level*/)
  const auto* batched = maybeGetBatchedImpl(self);
  TORCH_INTERNAL_ASSERT(batched != nullptr);

  Tensor self_without_bdim;
  int64_t newly_exposed_logical_dim;
  std::tie(self_without_bdim, newly_exposed_logical_dim) = remove_existing_batch_dim(batched, level);
  return _movedim(self_without_bdim, newly_exposed_logical_dim, out_dim);
}

Tensor _wrap_for_grad(const Tensor& self, int64_t level) {
  // NB: different behavior inside??
  // return self;
  // TORCH_INTERNAL_ASSERT(!maybeGetTensorWrapper(self));
  // TORCH_INTERNAL_ASSERT(self.has_storage());
  return makeTensorWrapper(self, level);
}

Tensor _unwrap_for_grad(const Tensor& self, int64_t level) {
  auto* result = maybeGetTensorWrapper(self);
  if (!result) {
    return self;
  }
  TORCH_INTERNAL_ASSERT(result->level().has_value());
  if (result->level() == level) {
    return result->value();
  }
  return self;
}

int64_t dlevel(const Tensor& tensor) {
  auto* wrapped = maybeGetTensorWrapper(tensor);
  if (!wrapped) {
    return 0;
  }
  if (!wrapped->is_alive()) {
    return -1;
  }
  return wrapped->level().value();
}

bool dump_tensor(const Tensor& self) {
  dumpTensorCout(self);
  return true;
}

int64_t _grad_increment_nesting() {
  return initAndPushDynamicLayer(at::DispatchKey::Autograd);
}

int64_t _grad_decrement_nesting() {
  return popDynamicLayerAndDeleteMetadata().layerId();
}

} // namespace functorch
} // namespace at
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("_add_batch_dim", &at::functorch::_add_batch_dim, "add batch dim");
  m.def("_remove_batch_dim", &at::functorch::_remove_batch_dim, "remove batch dim");
  m.def("_vmapmode_increment_nesting", &at::functorch::impl::VmapMode::increment_nesting, "increment vmap nesting");
  m.def("_vmapmode_decrement_nesting", &at::functorch::impl::VmapMode::decrement_nesting, "decrement vmap nesting");
  m.def("_grad_increment_nesting", &at::functorch::_grad_increment_nesting, "increment grad nesting");
  m.def("_grad_decrement_nesting", &at::functorch::_grad_decrement_nesting, "decrement grad nesting");
  m.def("_wrap_for_grad", &at::functorch::_wrap_for_grad, "wrap tensor for grad");
  m.def("_unwrap_for_grad", &at::functorch::_unwrap_for_grad, "unwrap tensor for grad");
  m.def("dlevel", &at::functorch::dlevel, "get dynamic level of tensor");
  m.def("dump_tensor", &at::functorch::dump_tensor, "dump tensor for debugging");
}
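Taken together, these bindings are enough to drive a one-level vmap by hand. The following sketch shows the intended call pattern (hypothetical direct use, and it assumes _vmapmode_increment_nesting returns the new level; the real wrapper in functorch's Python layer handles levels, pytrees, and error checking):

import torch
import functorch._C as _C

x = torch.randn(5, 3)
level = _C._vmapmode_increment_nesting()
try:
    bx = _C._add_batch_dim(x, 0, level)          # logical shape: (3,)
    by = torch.sin(bx)                           # runs the batching rule
    out = _C._remove_batch_dim(by, level, 5, 0)  # back to (5, 3)
finally:
    _C._vmapmode_decrement_nesting()
assert torch.allclose(out, torch.sin(x))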
71
functorch/setup.py
Normal file
@ -0,0 +1,71 @@
import distutils
import shutil
import glob
import os
from setuptools import setup, find_packages
from torch.utils.cpp_extension import (
    CppExtension,
    BuildExtension,
)


# class clean(distutils.command.clean.clean):
#     def run(self):
#         with open(".gitignore", "r") as f:
#             ignores = f.read()
#             for wildcard in filter(None, ignores.split("\n")):
#                 for filename in glob.glob(wildcard):
#                     try:
#                         os.remove(filename)
#                     except OSError:
#                         shutil.rmtree(filename, ignore_errors=True)
#
#         # It's an old-style class in Python 2.7...
#         distutils.command.clean.clean.run(self)


def get_extensions():
    extension = CppExtension

    define_macros = []

    extra_link_args = []
    extra_compile_args = {"cxx": ["-O3", "-g", "-std=c++14"]}
    if int(os.environ.get("DEBUG", 0)):
        extra_compile_args = {
            "cxx": ["-O0", "-fno-inline", "-g", "-std=c++14"]}
        extra_link_args = ["-O0", "-g"]

    this_dir = os.path.dirname(os.path.abspath(__file__))
    extensions_dir = os.path.join(this_dir, "functorch", "csrc")

    extension_sources = set(
        os.path.join(extensions_dir, p)
        for p in glob.glob(os.path.join(extensions_dir, "*.cpp"))
    )
    sources = list(extension_sources)
    include_dirs = [extensions_dir]

    ext_modules = [
        extension(
            "functorch._C",
            sources,
            include_dirs=[this_dir],
            define_macros=define_macros,
            extra_compile_args=extra_compile_args,
            extra_link_args=extra_link_args,
        )
    ]

    return ext_modules


setup(
    name='functorch',
    url="https://github.com/zou3519/functorch",
    packages=find_packages(),
    ext_modules=get_extensions(),
    cmdclass={
        # "clean": clean,
        "build_ext": BuildExtension
    })
438
functorch/test/test_eager_transforms.py
Normal file
@ -0,0 +1,438 @@
from torch.testing._internal.common_utils import TestCase, run_tests
import torch
import torch.nn as nn
import torch.nn.functional as F
import unittest
import functools
import itertools
import warnings
from torch.testing._internal.common_device_type import instantiate_device_type_tests, \
    skipCUDAIfNoMagma
import types
from functools import partial

import functorch
from functorch import grad, vjp, vmap, make_functional, jacrev


class TestGradTransform(TestCase):
    def test_primitive(self):
        x = torch.randn([])
        result = grad(torch.sin)(x)
        self.assertEqual(result, torch.cos(x))

    def test_composite_simple(self):
        x = torch.randn(2, 3, 4)
        result = grad(lambda x: torch.flatten(x).sum())(x)
        self.assertEqual(result, torch.ones_like(x))

    def test_composite_complicated(self):
        x = torch.randn(3)
        y = torch.randn(3, 5)

        def foo(x, y):
            result = x @ y
            return result.sum()

        result = grad(foo)(x, y)

        x.requires_grad_()
        out = foo(x, y)
        expected, = torch.autograd.grad(out, x)

        self.assertEqual(result, expected)

    def test_composite_two_ops(self):
        N, C = 2, 5
        y = torch.randn(N, C)
        targets = torch.randint(0, C, (N,))

        def foo(y, targets):
            return F.cross_entropy(y, targets)

        result = grad(foo)(y, targets)

        y.requires_grad_()
        expected, = torch.autograd.grad(foo(y, targets), y)

        self.assertEqual(result, expected)

    def _test_attributes(self, get_attr_lambda):
        x = torch.randn(2, 3, 5, dtype=torch.double)
        expected = get_attr_lambda(x)

        def foo(x):
            self.assertEqual(get_attr_lambda(x), expected)
            return x.sum()

        grad(foo)(x)

    def test_shape(self):
        self._test_attributes(lambda x: x.shape)

    def test_dtype(self):
        self._test_attributes(lambda x: x.dtype)

    def test_is_cuda(self):
        self._test_attributes(lambda x: x.is_cuda)

    def test_numel(self):
        self._test_attributes(lambda x: x.numel())

    def test_inplace(self):
        x = torch.randn([])

        def foo(x):
            return x.clone().sin_()

        result = grad(foo)(x)
        self.assertEqual(result, x.cos())

    def test_inplace_on_view(self):
        x = torch.randn(3)

        def foo(x):
            y = x.clone()
            y0 = y[0]
            y0.sin_()
            return y.sum()

        result = grad(foo)(x)

        x.requires_grad_()
        out = foo(x)
        expected, = torch.autograd.grad(out, x)

        self.assertEqual(result, expected)

    def test_inplace_on_view_base(self):
        x = torch.randn(3)

        def foo(x):
            y = x.clone()
            y0 = y[0]
            y.sin_()
            return y0

        result = grad(foo)(x)

        x.requires_grad_()
        out = foo(x)
        expected, = torch.autograd.grad(out, x)

        self.assertEqual(result, expected)

    def test_nesting_simple(self):
        x = torch.randn([])
        result = grad(grad(torch.sin))(x)
        self.assertEqual(result, -torch.sin(x))

    def test_escaped_wrappers_are_marked_as_dead(self):
        x = torch.randn([])
        escaped = []

        def foo(x):
            y = x.sin()
            escaped.append(y)
            return y

        result = grad(foo)(x)
        self.assertEqual(escaped[0].dlevel(), -1)

    def test_escaped_wrappers_are_ignored(self):
        x = torch.randn([])
        escaped = []

        def foo(x):
            y = x.sin()
            escaped.append(y)
            return y

        result = grad(foo)(x)

        something = escaped[0].sum()
        self.assertEqual(something.dlevel(), 0)
        self.assertEqual(something, x.sin().sum())

    def test_vjp(self):
        x = torch.randn([])
        out, vjp_fn = vjp(torch.sin, x)
        self.assertEqual(out, x.sin())

        v = torch.randn([])
        result, = vjp_fn(v)
        self.assertEqual(result, v * x.cos())

    def test_composed_with_autograd(self):
        x = torch.randn([], requires_grad=True)

        y = grad(torch.sin)(x)
        result, = torch.autograd.grad(y, x)
        self.assertEqual(result, -x.sin())

    def test_grad_of_vjp_composition(self):
        x = torch.randn([])
        y = torch.randn([])

        def foo(x, y):
            out, vjp_fn = vjp(torch.sin, x)
            return grad(lambda y: vjp_fn(y)[0])(y)

        result = foo(x, y)
        expected = x.cos()
        self.assertEqual(result, expected)

    def test_vjp_of_grad_composition(self):
        x = torch.randn([])
        y = torch.randn([])

        def foo(x, y):
            out, vjp_fn = vjp(grad(torch.sin), x)
            return vjp_fn(y)[0]

        result = foo(x, y)
        expected = -y * x.sin()
        self.assertEqual(result, expected)

    def test_grad_of_vjp_of_grad_composition(self):
        x = torch.randn([])
        y = torch.randn([])

        def foo(x, y):
            df, vjp_fn = vjp(grad(lambda x: -torch.cos(x)), x)
            return grad(lambda y: vjp_fn(y)[0])(y)

        result = foo(x, y)
        expected = x.cos()
        self.assertEqual(result, expected)

    def test_views(self):
        x = torch.randn([], requires_grad=True)
        y = torch.randn([], requires_grad=True)

        def silly_sin(x):
            x = x.view([])
            x = x.sin()
            return x

        def foo(x, y):
            z1 = grad(silly_sin)(x)
            z2 = torch.cos(y)
            return z1 + z2

        result = foo(x, y)
        grads = torch.autograd.grad(result, [x, y])
        self.assertEqual(grads[0], -x.sin())
        self.assertEqual(grads[1], -y.sin())

    def test_view_inplace_simple(self):
        def foo(x):
            x = x.clone()
            x.view([]).sin_()
            return x

        x = torch.randn([], requires_grad=True)
        result = grad(foo)(x)
        self.assertEqual(result, x.cos())


class TestVmapOfGrad(TestCase):
    def test_per_sample_grads_inplace_view(self):
        def compute_loss(weight, x, t):
            x = x.mm(weight)
            y = x.squeeze_(0)
            return (y - t).sum()

        weight = torch.randn(16, 2)
        x = torch.randn(64, 1, 16)
        t = torch.randn(64, 2)
        result = vmap(partial(grad(compute_loss), weight))(x, t)
        expected = [grad(compute_loss)(weight, x[i], t[i]) for i in range(64)]
        expected = torch.stack(expected)
        # TODO: Check if the rtol is a problem
        self.assertEqual(result, expected, atol=0, rtol=5e-4)

    def test_new_zeros_materializes_tensor(self):
        N = 3
        C = 5

        def foo(x, y):
            result = x.new_zeros((C,))
            result.copy_(y)
            return result.sum()

        x = torch.randn(N)
        y = torch.randn(N, C)
        result = vmap(grad(foo))(x, y)

    def test_per_sample_grads_simple(self):
        def compute_loss(weight, x, t):
            y = x @ weight
            return ((y - t) ** 2).sum()

        weight = torch.randn(16, 2)
        x = torch.randn(64, 16)
        t = torch.randn(64, 2)
        result = vmap(partial(grad(compute_loss), weight))(x, t)
        expected = [grad(compute_loss)(weight, x[i], t[i]) for i in range(64)]
        expected = torch.stack(expected)
        # TODO: Check if the rtol is a problem
        self.assertEqual(result, expected, atol=0, rtol=5e-4)

    def test_per_sample_grads_embeddingnet(self):
        class SampleNet(nn.Module):
            def __init__(self, vocab_size: int):
                super().__init__()
                self.emb = nn.Embedding(vocab_size, 16)
                self.fc1 = nn.Linear(16, 16)
                self.fc2 = nn.Linear(16, 2)

            def forward(self, x):
                x = self.emb(x)
                x = torch.transpose(x, -1, -2)
                x = torch.mean(x, -1)
                x = self.fc1(x)
                x = F.relu(x)
                x = self.fc2(x)
                return x

            def name(self):
                return "SampleNet"

        # Create our inputs...
        vocab_size = 1000
        batch_shape = [64]
        words_per_sentence = 5
        data = torch.randint(0, vocab_size, (*batch_shape, words_per_sentence))
        targets = torch.randint(0, 1, (*batch_shape,))

        # Construct our module
        net = SampleNet(vocab_size)
        criterion = nn.CrossEntropyLoss()

        params = dict(net.named_parameters())
        weights, net_func, _ = make_functional(net)

        def compute_loss(weights, data, target):
            output = net_func(weights, (data,))
            result = criterion(output, target)
            return result

        expected = [grad(compute_loss)(weights, data[i], targets[i]) for i in range(64)]
        expected = zip(*expected)
        expected = tuple(torch.stack(shards) for shards in expected)

        result = vmap(partial(grad(compute_loss), weights))(data, targets)
        for r, e in zip(result, expected):
            # TODO: Check if the rtol is a problem
            self.assertEqual(r, e, atol=0, rtol=1e-4)


class TestJacrev(TestCase):
    def test_simple(self):
        x = torch.randn(3)
        y = jacrev(torch.sin)(x)
        expected = torch.diagflat(x.cos())
        assert torch.allclose(y, expected)

    def test_simple_not_flat(self):
        x = torch.randn(2, 3)
        y = jacrev(torch.sin)(x)
        expected = torch.diagflat(x.view(-1).cos())
        expected = expected.view(2, 3, 2, 3)
        assert torch.allclose(y, expected)

    def test_vmap_on_jacrev_simple(self):
        x = torch.randn(2, 3)
        y = vmap(jacrev(torch.sin))(x)
        expected = torch.stack([torch.diagflat(x[i].cos()) for i in range(2)])
        assert torch.allclose(y, expected)

    def test_hessian_simple(self):
        def foo(x):
            return x.sin().sum()

        x = torch.randn(3)
        y = jacrev(jacrev(foo))(x)
        expected = torch.diagflat(-x.sin())
        assert torch.allclose(y, expected)


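The tests above lean on the identity that the rows of a Jacobian are vjps against standard basis vectors, which is how a reverse-mode jacrev can be built from vjp and vmap. A sketch (illustrative, not functorch's actual implementation; it reuses the imports at the top of this file):

def jacrev_sketch(f, x):
    y, vjp_fn = vjp(f, x)
    # One basis vector per output element; vmap computes all rows at once.
    basis = torch.eye(y.numel()).view(y.numel(), *y.shape)
    rows = vmap(lambda v: vjp_fn(v)[0])(basis)
    return rows.view(*y.shape, *x.shape)

x = torch.randn(3)
assert torch.allclose(jacrev_sketch(torch.sin, x), torch.diagflat(x.cos()))
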
class TestComposability(TestCase):
    def test_grad_grad(self):
        x = torch.randn([])
        y = grad(grad(torch.sin))(x)
        self.assertEqual(y, -x.sin())

    def test_grad_vmap(self):
        def foo(x):
            y = vmap(torch.sin)(x)
            return y.sum()

        x = torch.randn(3)
        y = grad(foo)(x)
        self.assertEqual(y, x.cos())

    def test_grad_vjp(self):
        x = torch.randn(3)

        def foo(x):
            _, vjp_fn = vjp(torch.sin, x)
            return vjp_fn(x)[0].sum()

        y = grad(foo)(x)
        expected = grad(lambda x: (x * x.cos()).sum())(x)
        self.assertEqual(y, expected)

    def test_vmap_grad(self):
        x = torch.randn(3)
        y = vmap(grad(torch.sin))(x)
        self.assertEqual(y, x.cos())

    def test_vmap_vmap(self):
        x = torch.randn(2, 3)
        y = vmap(vmap(torch.sin))(x)
        self.assertEqual(y, x.sin())

    def test_vmap_vjp(self):
        x = torch.randn(3)
        _, vjp_fn = vjp(torch.sin, x)

        def foo(x):
            _, vjp_fn = vjp(torch.sin, x)
            return vjp_fn(x)

        y = vmap(foo)(x)
        self.assertEqual(y, vjp_fn(x))

        xs = torch.randn(5, 3)
        expected = torch.stack([vjp_fn(x)[0] for x in xs])
        self.assertEqual(vmap(lambda x: vjp_fn(x)[0])(xs), expected)

    def test_vjp_grad(self):
        x = torch.randn([])
        y, vjp_fn = vjp(grad(torch.sin), x)
        self.assertEqual(y, x.cos())

        v = torch.randn([])
        self.assertEqual(vjp_fn(v)[0], -x.sin() * v)

    def test_vjp_vmap(self):
        x = torch.randn(3)
        y, vjp_fn = vjp(vmap(torch.sin), x)
        self.assertEqual(y, x.sin())

        v = torch.randn(3)
        self.assertEqual(vjp_fn(v)[0], x.cos() * v)

    def test_vjp_vjp(self):
        x = torch.randn(3)
        y, vjp_fn = vjp(torch.sin, x)
        self.assertEqual(y, x.sin())

        y, vjp_fn = vjp(lambda x: vjp_fn(x)[0], x)
        self.assertEqual(y, x * x.cos())

        y = vjp_fn(x)[0]
        # Honestly IDK what the result here is... but at least it runs


if __name__ == '__main__':
    run_tests()
2516
functorch/test/test_vmap.py
Normal file
File diff suppressed because it is too large