[functorch] Fix up the cifar10_transforms example

Richard Zou
2021-04-30 12:44:15 -07:00
committed by Jon Janzen
parent dfeb7898f2
commit 98df806b95
8 changed files with 12 additions and 1928 deletions

View File

@@ -25,9 +25,8 @@ from torchvision.datasets import CIFAR10
from tqdm import tqdm
from make_functional import make_functional, load_weights
from torch.eager_transforms import vmap, grad_with_value
from functools import partial
# from resnet import resnet18
from functorch import vmap, grad_and_value
def save_checkpoint(state, is_best, filename="checkpoint.tar"):
torch.save(state, filename)
@@ -104,7 +103,7 @@ def train(args, model, train_loader, optimizer, epoch, device):
# We want to extract some intermediate
# values from the computation (i.e. the loss and output).
#
# To extract the loss, we use the `grad_with_value` API, which returns the
# To extract the loss, we use the `grad_and_value` API, which returns the
# gradient of the loss w.r.t. the weights, along with the loss itself.
#
# To extract the output, we use the `has_aux=True` flag.
@@ -112,13 +111,9 @@ def train(args, model, train_loader, optimizer, epoch, device):
# where the first is to be differentiated and the second "auxiliary value"
# is not to be differentiated. `f'` returns the gradient of the loss w.r.t.
# the weights, the loss, and the auxiliary value.
grads_loss_output = grad_with_value(compute_loss_and_output, has_aux=True)
def packed(weights, images, target):
grads, loss, output = grads_loss_output(weights, images, target)
result = tuple([*grads, loss, output])
return result
result = vmap(partial(packed, weights))(images, target)
sample_grads, sample_loss, output, = result[:-2], result[-2], result[-1]
grads_loss_output = grad_and_value(compute_loss_and_output, has_aux=True)
sample_grads, (sample_loss, output) = \
vmap(grads_loss_output, (None, 0, 0))(weights, images, target)
loss = sample_loss.mean()
# `load_weights` is the inverse operation of make_functional. We put
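For reference, here is the same `grad_and_value` + `vmap` composition in isolation — a minimal sketch rather than part of the commit, using a toy linear model in place of the CIFAR10 network; the shapes and tensor names are illustrative assumptions, while the functorch calls mirror the updated code above:

import torch
import torch.nn.functional as F
from functorch import vmap, grad_and_value

def compute_loss_and_output(weights, image, target):
    # single-example forward pass: the loss is differentiated, the logits are the aux value
    logits = image @ weights
    loss = F.cross_entropy(logits.unsqueeze(0), target.unsqueeze(0))
    return loss, logits

weights = torch.randn(32, 10)
images = torch.randn(64, 32)
targets = torch.randint(0, 10, (64,))

# grad_and_value(..., has_aux=True) yields (grads, (loss, aux));
# in_dims=(None, 0, 0) broadcasts the weights and maps over the batch dimension.
per_sample_grads, (per_sample_losses, outputs) = vmap(
    grad_and_value(compute_loss_and_output, has_aux=True), (None, 0, 0)
)(weights, images, targets)
print(per_sample_grads.shape)  # torch.Size([64, 32, 10]): one gradient per example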

View File

@@ -1,438 +0,0 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
Runs CIFAR10 training with differential privacy.
"""
import argparse
import os
import shutil
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data
import torch.utils.data.distributed
import torch.utils.tensorboard as tensorboard
import torchvision.models as models
import torchvision.transforms as transforms
from opacus import PrivacyEngine
from opacus.utils import stats
from opacus.utils.module_modification import convert_batchnorm_modules
from torchvision.datasets import CIFAR10
from tqdm import tqdm
from functools import partial
from make_functional import make_functional, load_weights
# NB: The following might not exist depending on what you're using
from torch import vmap
from functional_utils import grad, grad_with_value
def save_checkpoint(state, is_best, filename="checkpoint.tar"):
torch.save(state, filename)
if is_best:
shutil.copyfile(filename, "model_best.pth.tar")
def accuracy(preds, labels):
return (preds == labels).mean()
def train(args, model, train_loader, optimizer, epoch, device):
model.train()
criterion = nn.CrossEntropyLoss()
losses = []
top1_acc = []
for i, (images, target) in enumerate(tqdm(train_loader)):
images = images.to(device)
target = target.to(device)
# Step 1: compute per-sample-grads
# TODO(rzou): does the group norm work correctly?
weights, func_model, descriptors = make_functional(model)
[weight.requires_grad_(False) for weight in weights]
def compute_loss_and_output(weights, image, target):
images = image.unsqueeze(0)
target = target.unsqueeze(0)
output = func_model(weights, (images,))
loss = criterion(output, target)
return loss, output.squeeze(0)
grads_loss_output = grad_with_value(compute_loss_and_output, has_aux=True)
grads, sample_loss, output = vmap(partial(grads_loss_output, weights))(images, target)
loss = sample_loss.mean(0)
# Step 2: Clip the per-sample-grads and sum them to form grads
# TODO(rzou): Right now we just sum the grads. Instead we need to clip them.
for sample_grad, weight in zip(grads, weights):
weight_grad = sample_grad.sum(0)
weight.grad = weight_grad
# `load_weights` is the inverse operation of make_functional. We put
# things back into a model so that we can directly apply optimizers.
# TODO(rzou): this might not be necessary, optimizers just take
# the params straight up.
[weight.requires_grad_(True) for weight in weights]
load_weights(model, descriptors, weights, as_params=True)
preds = np.argmax(output.detach().cpu().numpy(), axis=1)
labels = target.detach().cpu().numpy()
# measure accuracy and record loss
acc1 = accuracy(preds, labels)
losses.append(loss.item())
top1_acc.append(acc1)
stats.update(stats.StatType.TRAIN, acc1=acc1)
# make sure we take a step after processing the last mini-batch in the
# epoch to ensure we start the next epoch with a clean state
if ((i + 1) % args.n_accumulation_steps == 0) or ((i + 1) == len(train_loader)):
optimizer.step()
optimizer.zero_grad()
else:
optimizer.virtual_step()
if i % args.print_freq == 0:
if not args.disable_dp:
epsilon, best_alpha = optimizer.privacy_engine.get_privacy_spent(
args.delta
)
print(
f"\tTrain Epoch: {epoch} \t"
f"Loss: {np.mean(losses):.6f} "
f"Acc@1: {np.mean(top1_acc):.6f} "
f"(ε = {epsilon:.2f}, δ = {args.delta}) for α = {best_alpha}"
)
else:
print(
f"\tTrain Epoch: {epoch} \t"
f"Loss: {np.mean(losses):.6f} "
f"Acc@1: {np.mean(top1_acc):.6f} "
)
def test(args, model, test_loader, device):
model.eval()
criterion = nn.CrossEntropyLoss()
losses = []
top1_acc = []
with torch.no_grad():
for images, target in tqdm(test_loader):
images = images.to(device)
target = target.to(device)
output = model(images)
loss = criterion(output, target)
preds = np.argmax(output.detach().cpu().numpy(), axis=1)
labels = target.detach().cpu().numpy()
acc1 = accuracy(preds, labels)
losses.append(loss.item())
top1_acc.append(acc1)
top1_avg = np.mean(top1_acc)
stats.update(stats.StatType.TEST, acc1=top1_avg)
print(f"\tTest set:" f"Loss: {np.mean(losses):.6f} " f"Acc@1: {top1_avg :.6f} ")
return np.mean(top1_acc)
def main():
parser = argparse.ArgumentParser(description="PyTorch CIFAR10 DP Training")
parser.add_argument(
"-j",
"--workers",
default=2,
type=int,
metavar="N",
help="number of data loading workers (default: 2)",
)
parser.add_argument(
"--epochs",
default=90,
type=int,
metavar="N",
help="number of total epochs to run",
)
parser.add_argument(
"--start-epoch",
default=1,
type=int,
metavar="N",
help="manual epoch number (useful on restarts)",
)
parser.add_argument(
"-b",
"--batch-size",
default=256,
type=int,
metavar="N",
help="mini-batch size (default: 256), this is the total "
"batch size of all GPUs on the current node when "
"using Data Parallel or Distributed Data Parallel",
)
parser.add_argument(
"-na",
"--n_accumulation_steps",
default=1,
type=int,
metavar="N",
help="number of mini-batches to accumulate into an effective batch",
)
parser.add_argument(
"--lr",
"--learning-rate",
default=0.001,
type=float,
metavar="LR",
help="initial learning rate",
dest="lr",
)
parser.add_argument(
"--momentum", default=0.9, type=float, metavar="M", help="SGD momentum"
)
parser.add_argument(
"--wd",
"--weight-decay",
default=5e-4,
type=float,
metavar="W",
help="SGD weight decay (default: 1e-4)",
dest="weight_decay",
)
parser.add_argument(
"-p",
"--print-freq",
default=10,
type=int,
metavar="N",
help="print frequency (default: 10)",
)
parser.add_argument(
"--resume",
default="",
type=str,
metavar="PATH",
help="path to latest checkpoint (default: none)",
)
parser.add_argument(
"-e",
"--evaluate",
dest="evaluate",
action="store_true",
help="evaluate model on validation set",
)
parser.add_argument(
"--seed", default=None, type=int, help="seed for initializing training. "
)
parser.add_argument(
"--device",
type=str,
default="cuda",
help="GPU ID for this process (default: 'cuda')",
)
parser.add_argument(
"--sigma",
type=float,
default=1.0,
metavar="S",
help="Noise multiplier (default 1.0)",
)
parser.add_argument(
"-c",
"--max-per-sample-grad_norm",
type=float,
default=1.0,
metavar="C",
help="Clip per-sample gradients to this norm (default 1.0)",
)
parser.add_argument(
"--disable-dp",
action="store_true",
default=False,
help="Disable privacy training and just train with vanilla SGD",
)
parser.add_argument(
"--secure-rng",
action="store_true",
default=False,
help="Enable Secure RNG to have trustworthy privacy guarantees. Comes at a performance cost",
)
parser.add_argument(
"--delta",
type=float,
default=1e-5,
metavar="D",
help="Target delta (default: 1e-5)",
)
parser.add_argument(
"--checkpoint-file",
type=str,
default="checkpoint",
help="path to save check points",
)
parser.add_argument(
"--data-root",
type=str,
default="../cifar10",
help="Where CIFAR10 is/will be stored",
)
parser.add_argument(
"--log-dir", type=str, default="", help="Where Tensorboard log will be stored"
)
parser.add_argument(
"--optim",
type=str,
default="Adam",
help="Optimizer to use (Adam, RMSprop, SGD)",
)
args = parser.parse_args()
args.disable_dp = True
if args.disable_dp and args.n_accumulation_steps > 1:
raise ValueError("Virtual steps only works with enabled DP")
# The following few lines, enable stats gathering about the run
# 1. where the stats should be logged
stats.set_global_summary_writer(
tensorboard.SummaryWriter(os.path.join("/tmp/stat", args.log_dir))
)
# 2. enable stats
stats.add(
# stats about gradient norms aggregated for all layers
stats.Stat(stats.StatType.GRAD, "AllLayers", frequency=0.1),
# stats about gradient norms per layer
stats.Stat(stats.StatType.GRAD, "PerLayer", frequency=0.1),
# stats about clipping
stats.Stat(stats.StatType.GRAD, "ClippingStats", frequency=0.1),
# stats on training accuracy
stats.Stat(stats.StatType.TRAIN, "accuracy", frequency=0.01),
# stats on validation accuracy
stats.Stat(stats.StatType.TEST, "accuracy"),
)
# The following lines enable stat gathering for the clipping process
# and set a default of per layer clipping for the Privacy Engine
clipping = {"clip_per_layer": False, "enable_stat": True}
if args.secure_rng:
assert False
try:
import torchcsprng as prng
except ImportError as e:
msg = (
"To use secure RNG, you must install the torchcsprng package! "
"Check out the instructions here: https://github.com/pytorch/csprng#installation"
)
raise ImportError(msg) from e
generator = prng.create_random_device_generator("/dev/urandom")
else:
generator = None
augmentations = [
transforms.RandomCrop(32, padding=4),
transforms.RandomHorizontalFlip(),
]
normalize = [
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
]
train_transform = transforms.Compose(
augmentations + normalize if args.disable_dp else normalize
)
test_transform = transforms.Compose(normalize)
train_dataset = CIFAR10(
root=args.data_root, train=True, download=True, transform=train_transform
)
train_loader = torch.utils.data.DataLoader(
train_dataset,
batch_size=args.batch_size,
shuffle=True,
num_workers=args.workers,
drop_last=True,
generator=generator,
)
test_dataset = CIFAR10(
root=args.data_root, train=False, download=True, transform=test_transform
)
test_loader = torch.utils.data.DataLoader(
test_dataset,
batch_size=args.batch_size,
shuffle=False,
num_workers=args.workers,
)
best_acc1 = 0
device = torch.device(args.device)
model = convert_batchnorm_modules(models.resnet18(num_classes=10))
model = model.to(device)
if args.optim == "SGD":
optimizer = optim.SGD(
model.parameters(),
lr=args.lr,
momentum=args.momentum,
weight_decay=args.weight_decay,
)
elif args.optim == "RMSprop":
optimizer = optim.RMSprop(model.parameters(), lr=args.lr)
elif args.optim == "Adam":
optimizer = optim.Adam(model.parameters(), lr=args.lr)
else:
raise NotImplementedError("Optimizer not recognized. Please check spelling")
if not args.disable_dp:
privacy_engine = PrivacyEngine(
model,
batch_size=args.batch_size * args.n_accumulation_steps,
sample_size=len(train_dataset),
alphas=[1 + x / 10.0 for x in range(1, 100)] + list(range(12, 64)),
noise_multiplier=args.sigma,
max_grad_norm=args.max_per_sample_grad_norm,
secure_rng=args.secure_rng,
**clipping,
)
privacy_engine.attach(optimizer)
for epoch in range(args.start_epoch, args.epochs + 1):
train(args, model, train_loader, optimizer, epoch, device)
top1_acc = test(args, model, test_loader, device)
# remember best acc@1 and save checkpoint
is_best = top1_acc > best_acc1
best_acc1 = max(top1_acc, best_acc1)
save_checkpoint(
{
"epoch": epoch + 1,
"arch": "ResNet18",
"state_dict": model.state_dict(),
"best_acc1": best_acc1,
"optimizer": optimizer.state_dict(),
},
is_best,
filename=args.checkpoint_file + ".tar",
)
if __name__ == "__main__":
main()

View File

@@ -1,491 +0,0 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
Runs CIFAR10 training with differential privacy.
"""
import argparse
import os
import shutil
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data
import torch.utils.data.distributed
import torch.utils.tensorboard as tensorboard
import torchvision.models as models
import torchvision.transforms as transforms
from opacus import PrivacyEngine
from opacus.utils import stats
from opacus.utils.module_modification import convert_batchnorm_modules
from torchvision.datasets import CIFAR10
from tqdm import tqdm
from torch import vmap
from make_functional import make_functional, load_weights
from functional_utils import grad, grad_with_value
from functools import partial
# from resnet import resnet18
class CudaMemoryLeakCheck():
def __init__(self, name):
self.name = name
# initialize context & RNG to prevent false positive detections
# when the test is the first to initialize those
from torch.testing._internal.common_cuda import initialize_cuda_context_rng
initialize_cuda_context_rng()
@staticmethod
def get_cuda_memory_usage():
# we don't need to call CUDA synchronize because the statistics are not tracked
# at the actual freeing, but when the block is marked as free.
num_devices = torch.cuda.device_count()
import gc
gc.collect()
return tuple(torch.cuda.memory_allocated(i) for i in range(num_devices))
def __enter__(self):
self.befores = self.get_cuda_memory_usage()
def __exit__(self, exec_type, exec_value, traceback):
# Don't check for leaks if an exception was thrown
if exec_type is not None:
return
afters = self.get_cuda_memory_usage()
for i, (before, after) in enumerate(zip(self.befores, afters)):
if after - before == 0:
continue
raise RuntimeError(f'{self.name} leaked {after-before} bytes')
def save_checkpoint(state, is_best, filename="checkpoint.tar"):
torch.save(state, filename)
if is_best:
shutil.copyfile(filename, "model_best.pth.tar")
def accuracy(preds, labels):
return (preds == labels).mean()
def compute_norms(sample_grads):
batch_size = sample_grads[0].shape[0]
norms = [sample_grad.view(batch_size, -1).norm(2, dim=-1) for sample_grad in sample_grads]
norms = torch.stack(norms, dim=0).norm(2, dim=0)
return norms
def clip_and_accumulate_and_add_noise(sample_grads, max_per_sample_grad_norm=1.0, noise_multiplier=1.0):
# step 0: compute the norms
sample_norms = compute_norms(sample_grads)
# step 1: compute clipping factors
clip_factor = max_per_sample_grad_norm / (sample_norms + 1e-6)
clip_factor = clip_factor.clamp(max=1.0)
# step 2: clip
grads = tuple(torch.einsum('i,i...', clip_factor, sample_grad)
for sample_grad in sample_grads)
# step 3: add gaussian noise
stddev = max_per_sample_grad_norm * noise_multiplier
noises = tuple(torch.normal(0, stddev, grad_param.shape, device=grad_param.device)
for grad_param in grads)
grads = tuple(noise + grad_param for noise, grad_param in zip(noises, grads))
return grads
def train(args, model, train_loader, optimizer, epoch, device):
use_prototype = False
model.train()
criterion = nn.CrossEntropyLoss()
losses = []
top1_acc = []
for i, (images, target) in enumerate(tqdm(train_loader)):
images = images.to(device)
target = target.to(device)
# Step 1: compute per-sample-grads
weights, func_model, descriptors = make_functional(model)
def compute_loss_and_output(weights, image, target):
images = image.unsqueeze(0)
targets = target.unsqueeze(0)
output = func_model(weights, (images,))
loss = criterion(output, targets)
return loss, output.squeeze(0)
# grad_with_value(f) returns a function that returns (1) the grad and
# (2) the output. `has_aux=True` means that `f` returns a tuple of two values,
# where the first is to be differentiated and the second is not to be
# differentiated and further adds a 3rd output.
#
# We need to use `grad_with_value(..., has_aux=True)` because we do
# some analyses on the returned loss and output.
grads_loss_output = grad_with_value(compute_loss_and_output, has_aux=True)
sample_grads, sample_loss, output = vmap(partial(grads_loss_output, weights))(images, target)
loss = sample_loss.mean()
# Step 2: Clip the per-sample-grads, sum them to form grads, and add noise
grads = clip_and_accumulate_and_add_noise(
sample_grads, args.max_per_sample_grad_norm, args.sigma)
# `load_weights` is the inverse operation of make_functional. We put
# things back into a model so that we can directly apply optimizers.
# TODO(rzou): this might not be necessary, optimizers just take
# the params straight up.
load_weights(model, descriptors, weights)
for weight_grad, weight in zip(grads, model.parameters()):
weight.grad = weight_grad.detach()
preds = np.argmax(output.detach().cpu().numpy(), axis=1)
labels = target.detach().cpu().numpy()
losses.append(loss.item())
# measure accuracy and record loss
acc1 = accuracy(preds, labels)
top1_acc.append(acc1)
stats.update(stats.StatType.TRAIN, acc1=acc1)
# make sure we take a step after processing the last mini-batch in the
# epoch to ensure we start the next epoch with a clean state
if ((i + 1) % args.n_accumulation_steps == 0) or ((i + 1) == len(train_loader)):
optimizer.step()
optimizer.zero_grad()
else:
optimizer.virtual_step()
if i % args.print_freq == 0:
print(
f"\tTrain Epoch: {epoch} \t"
f"Loss: {np.mean(losses):.6f} "
f"Acc@1: {np.mean(top1_acc):.6f} "
)
def test(args, model, test_loader, device):
model.eval()
criterion = nn.CrossEntropyLoss()
losses = []
top1_acc = []
with torch.no_grad():
for images, target in tqdm(test_loader):
images = images.to(device)
target = target.to(device)
output = model(images)
loss = criterion(output, target)
preds = np.argmax(output.detach().cpu().numpy(), axis=1)
labels = target.detach().cpu().numpy()
acc1 = accuracy(preds, labels)
losses.append(loss.item())
top1_acc.append(acc1)
top1_avg = np.mean(top1_acc)
stats.update(stats.StatType.TEST, acc1=top1_avg)
print(f"\tTest set:" f"Loss: {np.mean(losses):.6f} " f"Acc@1: {top1_avg :.6f} ")
return np.mean(top1_acc)
def main():
parser = argparse.ArgumentParser(description="PyTorch CIFAR10 DP Training")
parser.add_argument(
"-j",
"--workers",
default=2,
type=int,
metavar="N",
help="number of data loading workers (default: 2)",
)
parser.add_argument(
"--epochs",
default=90,
type=int,
metavar="N",
help="number of total epochs to run",
)
parser.add_argument(
"--start-epoch",
default=1,
type=int,
metavar="N",
help="manual epoch number (useful on restarts)",
)
parser.add_argument(
"-b",
"--batch-size",
# This should be 256, but that OOMs using the prototype.
default=64,
type=int,
metavar="N",
help="mini-batch size (default: 64), this is the total "
"batch size of all GPUs on the current node when "
"using Data Parallel or Distributed Data Parallel",
)
parser.add_argument(
"-na",
"--n_accumulation_steps",
default=1,
type=int,
metavar="N",
help="number of mini-batches to accumulate into an effective batch",
)
parser.add_argument(
"--lr",
"--learning-rate",
default=0.001,
type=float,
metavar="LR",
help="initial learning rate",
dest="lr",
)
parser.add_argument(
"--momentum", default=0.9, type=float, metavar="M", help="SGD momentum"
)
parser.add_argument(
"--wd",
"--weight-decay",
default=5e-4,
type=float,
metavar="W",
help="SGD weight decay (default: 1e-4)",
dest="weight_decay",
)
parser.add_argument(
"-p",
"--print-freq",
default=10,
type=int,
metavar="N",
help="print frequency (default: 10)",
)
parser.add_argument(
"--resume",
default="",
type=str,
metavar="PATH",
help="path to latest checkpoint (default: none)",
)
parser.add_argument(
"-e",
"--evaluate",
dest="evaluate",
action="store_true",
help="evaluate model on validation set",
)
parser.add_argument(
"--seed", default=None, type=int, help="seed for initializing training. "
)
parser.add_argument(
"--device",
type=str,
default="cuda",
help="GPU ID for this process (default: 'cuda')",
)
parser.add_argument(
"--sigma",
type=float,
default=1.0,
metavar="S",
help="Noise multiplier (default 1.0)",
)
parser.add_argument(
"-c",
"--max-per-sample-grad_norm",
type=float,
default=1.0,
metavar="C",
help="Clip per-sample gradients to this norm (default 1.0)",
)
parser.add_argument(
"--disable-dp",
action="store_true",
default=False,
help="Disable privacy training and just train with vanilla SGD",
)
parser.add_argument(
"--secure-rng",
action="store_true",
default=False,
help="Enable Secure RNG to have trustworthy privacy guarantees. Comes at a performance cost",
)
parser.add_argument(
"--delta",
type=float,
default=1e-5,
metavar="D",
help="Target delta (default: 1e-5)",
)
parser.add_argument(
"--checkpoint-file",
type=str,
default="checkpoint",
help="path to save check points",
)
parser.add_argument(
"--data-root",
type=str,
default="../cifar10",
help="Where CIFAR10 is/will be stored",
)
parser.add_argument(
"--log-dir", type=str, default="", help="Where Tensorboard log will be stored"
)
parser.add_argument(
"--optim",
type=str,
default="Adam",
help="Optimizer to use (Adam, RMSprop, SGD)",
)
args = parser.parse_args()
args.disable_dp = True
if args.disable_dp and args.n_accumulation_steps > 1:
raise ValueError("Virtual steps only works with enabled DP")
# The following few lines, enable stats gathering about the run
# 1. where the stats should be logged
stats.set_global_summary_writer(
tensorboard.SummaryWriter(os.path.join("/tmp/stat", args.log_dir))
)
# 2. enable stats
stats.add(
# stats about gradient norms aggregated for all layers
stats.Stat(stats.StatType.GRAD, "AllLayers", frequency=0.1),
# stats about gradient norms per layer
stats.Stat(stats.StatType.GRAD, "PerLayer", frequency=0.1),
# stats about clipping
stats.Stat(stats.StatType.GRAD, "ClippingStats", frequency=0.1),
# stats on training accuracy
stats.Stat(stats.StatType.TRAIN, "accuracy", frequency=0.01),
# stats on validation accuracy
stats.Stat(stats.StatType.TEST, "accuracy"),
)
# The following lines enable stat gathering for the clipping process
# and set a default of per layer clipping for the Privacy Engine
clipping = {"clip_per_layer": False, "enable_stat": True}
if args.secure_rng:
assert False
try:
import torchcsprng as prng
except ImportError as e:
msg = (
"To use secure RNG, you must install the torchcsprng package! "
"Check out the instructions here: https://github.com/pytorch/csprng#installation"
)
raise ImportError(msg) from e
generator = prng.create_random_device_generator("/dev/urandom")
else:
generator = None
augmentations = [
transforms.RandomCrop(32, padding=4),
transforms.RandomHorizontalFlip(),
]
normalize = [
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
]
train_transform = transforms.Compose(
augmentations + normalize if args.disable_dp else normalize
)
test_transform = transforms.Compose(normalize)
train_dataset = CIFAR10(
root=args.data_root, train=True, download=True, transform=train_transform
)
train_loader = torch.utils.data.DataLoader(
train_dataset,
batch_size=args.batch_size,
shuffle=True,
num_workers=args.workers,
drop_last=True,
generator=generator,
)
test_dataset = CIFAR10(
root=args.data_root, train=False, download=True, transform=test_transform
)
test_loader = torch.utils.data.DataLoader(
test_dataset,
batch_size=args.batch_size,
shuffle=False,
num_workers=args.workers,
)
best_acc1 = 0
device = torch.device(args.device)
model = convert_batchnorm_modules(models.resnet18(num_classes=10))
# model = CIFAR10Model()
model = model.to(device)
if args.optim == "SGD":
optimizer = optim.SGD(
model.parameters(),
lr=args.lr,
momentum=args.momentum,
weight_decay=args.weight_decay,
)
elif args.optim == "RMSprop":
optimizer = optim.RMSprop(model.parameters(), lr=args.lr)
elif args.optim == "Adam":
optimizer = optim.Adam(model.parameters(), lr=args.lr)
else:
raise NotImplementedError("Optimizer not recognized. Please check spelling")
if not args.disable_dp:
privacy_engine = PrivacyEngine(
model,
batch_size=args.batch_size * args.n_accumulation_steps,
sample_size=len(train_dataset),
alphas=[1 + x / 10.0 for x in range(1, 100)] + list(range(12, 64)),
noise_multiplier=args.sigma,
max_grad_norm=args.max_per_sample_grad_norm,
secure_rng=args.secure_rng,
**clipping,
)
privacy_engine.attach(optimizer)
for epoch in range(args.start_epoch, args.epochs + 1):
train(args, model, train_loader, optimizer, epoch, device)
top1_acc = test(args, model, test_loader, device)
# remember best acc@1 and save checkpoint
is_best = top1_acc > best_acc1
best_acc1 = max(top1_acc, best_acc1)
save_checkpoint(
{
"epoch": epoch + 1,
"arch": "ResNet18",
"state_dict": model.state_dict(),
"best_acc1": best_acc1,
"optimizer": optimizer.state_dict(),
},
is_best,
filename=args.checkpoint_file + ".tar",
)
if __name__ == "__main__":
main()
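
The script above spells out the DP-SGD gradient post-processing inline (per-sample norms, clipping, summation, Gaussian noise). Below is a self-contained sketch of the same arithmetic with made-up shapes; the logic follows the removed `compute_norms` and `clip_and_accumulate_and_add_noise`, everything else is an assumption:

import torch

# hypothetical per-sample grads: batch of 4, two parameter tensors
sample_grads = (torch.randn(4, 3, 3), torch.randn(4, 5))
max_norm, noise_multiplier = 1.0, 1.0

# per-sample norm across all parameters
batch_size = sample_grads[0].shape[0]
per_param_norms = [g.view(batch_size, -1).norm(2, dim=-1) for g in sample_grads]
per_sample_norms = torch.stack(per_param_norms, dim=0).norm(2, dim=0)

# clip each example's gradient and sum over the batch in a single einsum
clip_factor = (max_norm / (per_sample_norms + 1e-6)).clamp(max=1.0)
grads = tuple(torch.einsum('i,i...', clip_factor, g) for g in sample_grads)

# add Gaussian noise calibrated to the clipping threshold
grads = tuple(g + torch.normal(0, max_norm * noise_multiplier, g.shape) for g in grads)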

View File

@@ -1,457 +0,0 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
Runs CIFAR10 training with differential privacy.
"""
import argparse
import os
import shutil
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data
import torch.utils.data.distributed
import torch.utils.tensorboard as tensorboard
import torchvision.models as models
import torchvision.transforms as transforms
from opacus import PrivacyEngine
from opacus.utils import stats
from opacus.utils.module_modification import convert_batchnorm_modules
from torchvision.datasets import CIFAR10
from tqdm import tqdm
from torch import vmap
from make_functional import make_functional, load_weights
from functional_utils import grad, grad_with_value
from functools import partial
# from resnet import resnet18
def save_checkpoint(state, is_best, filename="checkpoint.tar"):
torch.save(state, filename)
if is_best:
shutil.copyfile(filename, "model_best.pth.tar")
def accuracy(preds, labels):
return (preds == labels).mean()
def compute_norms(sample_grads):
batch_size = sample_grads[0].shape[0]
norms = [sample_grad.view(batch_size, -1).norm(2, dim=-1) for sample_grad in sample_grads]
norms = torch.stack(norms, dim=0).norm(2, dim=0)
return norms
def clip_and_accumulate_and_add_noise(sample_grads, max_per_sample_grad_norm=1.0, noise_multiplier=1.0):
# step 0: compute the norms
sample_norms = compute_norms(sample_grads)
# step 1: compute clipping factors
clip_factor = max_per_sample_grad_norm / (sample_norms + 1e-6)
clip_factor = clip_factor.clamp(max=1.0)
# step 2: clip
grads = tuple(torch.einsum('i,i...', clip_factor, sample_grad)
for sample_grad in sample_grads)
# step 3: add gaussian noise
stddev = max_per_sample_grad_norm * noise_multiplier
noises = tuple(torch.normal(0, stddev, grad_param.shape, device=grad_param.device)
for grad_param in grads)
grads = tuple(noise + grad_param for noise, grad_param in zip(noises, grads))
return grads
def train(args, model, train_loader, optimizer, epoch, device):
use_prototype = False
model.train()
criterion = nn.CrossEntropyLoss()
losses = []
top1_acc = []
for i, (images, target) in enumerate(tqdm(train_loader)):
images = images.to(device)
target = target.to(device)
# Step 1: compute per-sample-grads
weights, func_model, descriptors = make_functional(model)
def compute_loss_and_output(weights, image, target):
images = image.unsqueeze(0)
targets = target.unsqueeze(0)
output = func_model(weights, (images,))
loss = criterion(output, targets)
return loss, output.squeeze(0)
# grad_with_value(f) returns a function that returns (1) the grad and
# (2) the output. `has_aux=True` means that `f` returns a tuple of two values,
# where the first is to be differentiated and the second is not to be
# differentiated and further adds a 3rd output.
#
# We need to use `grad_with_value(..., has_aux=True)` because we do
# some analyses on the returned loss and output.
grads_loss_output = grad_with_value(compute_loss_and_output, has_aux=True)
sample_grads, sample_loss, output = vmap(partial(grads_loss_output, weights))(images, target)
loss = sample_loss.mean()
# Step 2: Clip the per-sample-grads, sum them to form grads, and add noise
grads = clip_and_accumulate_and_add_noise(
sample_grads, args.max_per_sample_grad_norm, args.sigma)
# `load_weights` is the inverse operation of make_functional. We put
# things back into a model so that we can directly apply optimizers.
# TODO(rzou): this might not be necessary, optimizers just take
# the params straight up.
load_weights(model, descriptors, weights)
for weight_grad, weight in zip(grads, model.parameters()):
weight.grad = weight_grad.detach()
preds = np.argmax(output.detach().cpu().numpy(), axis=1)
labels = target.detach().cpu().numpy()
losses.append(loss.item())
# measure accuracy and record loss
acc1 = accuracy(preds, labels)
top1_acc.append(acc1)
stats.update(stats.StatType.TRAIN, acc1=acc1)
# make sure we take a step after processing the last mini-batch in the
# epoch to ensure we start the next epoch with a clean state
if ((i + 1) % args.n_accumulation_steps == 0) or ((i + 1) == len(train_loader)):
optimizer.step()
optimizer.zero_grad()
else:
optimizer.virtual_step()
if i % args.print_freq == 0:
print(
f"\tTrain Epoch: {epoch} \t"
f"Loss: {np.mean(losses):.6f} "
f"Acc@1: {np.mean(top1_acc):.6f} "
)
def test(args, model, test_loader, device):
model.eval()
criterion = nn.CrossEntropyLoss()
losses = []
top1_acc = []
with torch.no_grad():
for images, target in tqdm(test_loader):
images = images.to(device)
target = target.to(device)
output = model(images)
loss = criterion(output, target)
preds = np.argmax(output.detach().cpu().numpy(), axis=1)
labels = target.detach().cpu().numpy()
acc1 = accuracy(preds, labels)
losses.append(loss.item())
top1_acc.append(acc1)
top1_avg = np.mean(top1_acc)
stats.update(stats.StatType.TEST, acc1=top1_avg)
print(f"\tTest set:" f"Loss: {np.mean(losses):.6f} " f"Acc@1: {top1_avg :.6f} ")
return np.mean(top1_acc)
def main():
parser = argparse.ArgumentParser(description="PyTorch CIFAR10 DP Training")
parser.add_argument(
"-j",
"--workers",
default=2,
type=int,
metavar="N",
help="number of data loading workers (default: 2)",
)
parser.add_argument(
"--epochs",
default=90,
type=int,
metavar="N",
help="number of total epochs to run",
)
parser.add_argument(
"--start-epoch",
default=1,
type=int,
metavar="N",
help="manual epoch number (useful on restarts)",
)
parser.add_argument(
"-b",
"--batch-size",
# This should be 256, but that OOMs using the prototype.
default=64,
type=int,
metavar="N",
help="mini-batch size (default: 64), this is the total "
"batch size of all GPUs on the current node when "
"using Data Parallel or Distributed Data Parallel",
)
parser.add_argument(
"-na",
"--n_accumulation_steps",
default=1,
type=int,
metavar="N",
help="number of mini-batches to accumulate into an effective batch",
)
parser.add_argument(
"--lr",
"--learning-rate",
default=0.001,
type=float,
metavar="LR",
help="initial learning rate",
dest="lr",
)
parser.add_argument(
"--momentum", default=0.9, type=float, metavar="M", help="SGD momentum"
)
parser.add_argument(
"--wd",
"--weight-decay",
default=5e-4,
type=float,
metavar="W",
help="SGD weight decay (default: 1e-4)",
dest="weight_decay",
)
parser.add_argument(
"-p",
"--print-freq",
default=10,
type=int,
metavar="N",
help="print frequency (default: 10)",
)
parser.add_argument(
"--resume",
default="",
type=str,
metavar="PATH",
help="path to latest checkpoint (default: none)",
)
parser.add_argument(
"-e",
"--evaluate",
dest="evaluate",
action="store_true",
help="evaluate model on validation set",
)
parser.add_argument(
"--seed", default=None, type=int, help="seed for initializing training. "
)
parser.add_argument(
"--device",
type=str,
default="cuda",
help="GPU ID for this process (default: 'cuda')",
)
parser.add_argument(
"--sigma",
type=float,
default=1.0,
metavar="S",
help="Noise multiplier (default 1.0)",
)
parser.add_argument(
"-c",
"--max-per-sample-grad_norm",
type=float,
default=1.0,
metavar="C",
help="Clip per-sample gradients to this norm (default 1.0)",
)
parser.add_argument(
"--disable-dp",
action="store_true",
default=False,
help="Disable privacy training and just train with vanilla SGD",
)
parser.add_argument(
"--secure-rng",
action="store_true",
default=False,
help="Enable Secure RNG to have trustworthy privacy guarantees. Comes at a performance cost",
)
parser.add_argument(
"--delta",
type=float,
default=1e-5,
metavar="D",
help="Target delta (default: 1e-5)",
)
parser.add_argument(
"--checkpoint-file",
type=str,
default="checkpoint",
help="path to save check points",
)
parser.add_argument(
"--data-root",
type=str,
default="../cifar10",
help="Where CIFAR10 is/will be stored",
)
parser.add_argument(
"--log-dir", type=str, default="", help="Where Tensorboard log will be stored"
)
parser.add_argument(
"--optim",
type=str,
default="Adam",
help="Optimizer to use (Adam, RMSprop, SGD)",
)
args = parser.parse_args()
args.disable_dp = True
if args.disable_dp and args.n_accumulation_steps > 1:
raise ValueError("Virtual steps only works with enabled DP")
# The following few lines, enable stats gathering about the run
# 1. where the stats should be logged
stats.set_global_summary_writer(
tensorboard.SummaryWriter(os.path.join("/tmp/stat", args.log_dir))
)
# 2. enable stats
stats.add(
# stats about gradient norms aggregated for all layers
stats.Stat(stats.StatType.GRAD, "AllLayers", frequency=0.1),
# stats about gradient norms per layer
stats.Stat(stats.StatType.GRAD, "PerLayer", frequency=0.1),
# stats about clipping
stats.Stat(stats.StatType.GRAD, "ClippingStats", frequency=0.1),
# stats on training accuracy
stats.Stat(stats.StatType.TRAIN, "accuracy", frequency=0.01),
# stats on validation accuracy
stats.Stat(stats.StatType.TEST, "accuracy"),
)
# The following lines enable stat gathering for the clipping process
# and set a default of per layer clipping for the Privacy Engine
clipping = {"clip_per_layer": False, "enable_stat": True}
if args.secure_rng:
assert False
try:
import torchcsprng as prng
except ImportError as e:
msg = (
"To use secure RNG, you must install the torchcsprng package! "
"Check out the instructions here: https://github.com/pytorch/csprng#installation"
)
raise ImportError(msg) from e
generator = prng.create_random_device_generator("/dev/urandom")
else:
generator = None
augmentations = [
transforms.RandomCrop(32, padding=4),
transforms.RandomHorizontalFlip(),
]
normalize = [
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
]
train_transform = transforms.Compose(
augmentations + normalize if args.disable_dp else normalize
)
test_transform = transforms.Compose(normalize)
train_dataset = CIFAR10(
root=args.data_root, train=True, download=True, transform=train_transform
)
train_loader = torch.utils.data.DataLoader(
train_dataset,
batch_size=args.batch_size,
shuffle=True,
num_workers=args.workers,
drop_last=True,
generator=generator,
)
test_dataset = CIFAR10(
root=args.data_root, train=False, download=True, transform=test_transform
)
test_loader = torch.utils.data.DataLoader(
test_dataset,
batch_size=args.batch_size,
shuffle=False,
num_workers=args.workers,
)
best_acc1 = 0
device = torch.device(args.device)
model = convert_batchnorm_modules(models.resnet18(num_classes=10))
# model = CIFAR10Model()
model = model.to(device)
if args.optim == "SGD":
optimizer = optim.SGD(
model.parameters(),
lr=args.lr,
momentum=args.momentum,
weight_decay=args.weight_decay,
)
elif args.optim == "RMSprop":
optimizer = optim.RMSprop(model.parameters(), lr=args.lr)
elif args.optim == "Adam":
optimizer = optim.Adam(model.parameters(), lr=args.lr)
else:
raise NotImplementedError("Optimizer not recognized. Please check spelling")
if not args.disable_dp:
privacy_engine = PrivacyEngine(
model,
batch_size=args.batch_size * args.n_accumulation_steps,
sample_size=len(train_dataset),
alphas=[1 + x / 10.0 for x in range(1, 100)] + list(range(12, 64)),
noise_multiplier=args.sigma,
max_grad_norm=args.max_per_sample_grad_norm,
secure_rng=args.secure_rng,
**clipping,
)
privacy_engine.attach(optimizer)
for epoch in range(args.start_epoch, args.epochs + 1):
train(args, model, train_loader, optimizer, epoch, device)
top1_acc = test(args, model, test_loader, device)
# remember best acc@1 and save checkpoint
is_best = top1_acc > best_acc1
best_acc1 = max(top1_acc, best_acc1)
save_checkpoint(
{
"epoch": epoch + 1,
"arch": "ResNet18",
"state_dict": model.state_dict(),
"best_acc1": best_acc1,
"optimizer": optimizer.state_dict(),
},
is_best,
filename=args.checkpoint_file + ".tar",
)
if __name__ == "__main__":
main()

View File

@@ -1,456 +0,0 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
Runs CIFAR10 training with differential privacy.
"""
import argparse
import os
import shutil
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data
import torch.utils.data.distributed
import torch.utils.tensorboard as tensorboard
import torchvision.models as models
import torchvision.transforms as transforms
from opacus import PrivacyEngine
from opacus.utils import stats
from opacus.utils.module_modification import convert_batchnorm_modules
from torchvision.datasets import CIFAR10
from tqdm import tqdm
from torch import vmap
from make_functional import make_functional, load_weights
from functional_utils import grad, grad_with_value
from functools import partial
# from resnet import resnet18
def save_checkpoint(state, is_best, filename="checkpoint.tar"):
torch.save(state, filename)
if is_best:
shutil.copyfile(filename, "model_best.pth.tar")
def accuracy(preds, labels):
return (preds == labels).mean()
def compute_norms(sample_grads):
batch_size = sample_grads[0].shape[0]
norms = [sample_grad.view(batch_size, -1).norm(2, dim=-1) for sample_grad in sample_grads]
norms = torch.stack(norms, dim=0).norm(2, dim=0)
return norms
def clip_and_accumulate_and_add_noise(sample_grads, max_per_sample_grad_norm=1.0, noise_multiplier=1.0):
# step 0: compute the norms
sample_norms = compute_norms(sample_grads)
# step 1: compute clipping factors
clip_factor = max_per_sample_grad_norm / (sample_norms + 1e-6)
clip_factor = clip_factor.clamp(max=1.0)
# step 2: clip
grads = tuple(torch.einsum('i,i...', clip_factor, sample_grad)
for sample_grad in sample_grads)
# step 3: add gaussian noise
stddev = max_per_sample_grad_norm * noise_multiplier
noises = tuple(torch.normal(0, stddev, grad_param.shape, device=grad_param.device)
for grad_param in grads)
grads = tuple(noise + grad_param for noise, grad_param in zip(noises, grads))
return grads
def train(args, model, train_loader, optimizer, epoch, device):
model.train()
criterion = nn.CrossEntropyLoss()
losses = []
top1_acc = []
for i, (images, target) in enumerate(tqdm(train_loader)):
images = images.to(device)
target = target.to(device)
# Step 1: compute per-sample-grads
weights, func_model, descriptors = make_functional(model)
def compute_loss_and_output(weights, image, target):
images = image.unsqueeze(0)
targets = target.unsqueeze(0)
output = func_model(weights, (images,))
loss = criterion(output, targets)
return loss, output.squeeze(0)
# grad_with_value(f) returns a function that returns (1) the grad and
# (2) the output. `has_aux=True` means that `f` returns a tuple of two values,
# where the first is to be differentiated and the second is not to be
# differentiated and further adds a 3rd output.
#
# We need to use `grad_with_value(..., has_aux=True)` because we do
# some analyses on the returned loss and output.
grads_loss_output = grad_with_value(compute_loss_and_output, has_aux=True)
sample_grads, sample_loss, output = vmap(partial(grads_loss_output, weights))(images, target)
loss = sample_loss.mean()
# Step 2: Clip the per-sample-grads, sum them to form grads, and add noise
grads = clip_and_accumulate_and_add_noise(
sample_grads, args.max_per_sample_grad_norm, args.sigma)
# `load_weights` is the inverse operation of make_functional. We put
# things back into a model so that we can directly apply optimizers.
# TODO(rzou): this might not be necessary, optimizers just take
# the params straight up.
load_weights(model, descriptors, weights)
for weight_grad, weight in zip(grads, model.parameters()):
weight.grad = weight_grad.detach()
preds = np.argmax(output.detach().cpu().numpy(), axis=1)
labels = target.detach().cpu().numpy()
losses.append(loss.item())
# measure accuracy and record loss
acc1 = accuracy(preds, labels)
top1_acc.append(acc1)
stats.update(stats.StatType.TRAIN, acc1=acc1)
# make sure we take a step after processing the last mini-batch in the
# epoch to ensure we start the next epoch with a clean state
if ((i + 1) % args.n_accumulation_steps == 0) or ((i + 1) == len(train_loader)):
optimizer.step()
optimizer.zero_grad()
else:
optimizer.virtual_step()
if i % args.print_freq == 0:
print(
f"\tTrain Epoch: {epoch} \t"
f"Loss: {np.mean(losses):.6f} "
f"Acc@1: {np.mean(top1_acc):.6f} "
)
def test(args, model, test_loader, device):
model.eval()
criterion = nn.CrossEntropyLoss()
losses = []
top1_acc = []
with torch.no_grad():
for images, target in tqdm(test_loader):
images = images.to(device)
target = target.to(device)
output = model(images)
loss = criterion(output, target)
preds = np.argmax(output.detach().cpu().numpy(), axis=1)
labels = target.detach().cpu().numpy()
acc1 = accuracy(preds, labels)
losses.append(loss.item())
top1_acc.append(acc1)
top1_avg = np.mean(top1_acc)
stats.update(stats.StatType.TEST, acc1=top1_avg)
print(f"\tTest set:" f"Loss: {np.mean(losses):.6f} " f"Acc@1: {top1_avg :.6f} ")
return np.mean(top1_acc)
def main():
parser = argparse.ArgumentParser(description="PyTorch CIFAR10 DP Training")
parser.add_argument(
"-j",
"--workers",
default=2,
type=int,
metavar="N",
help="number of data loading workers (default: 2)",
)
parser.add_argument(
"--epochs",
default=90,
type=int,
metavar="N",
help="number of total epochs to run",
)
parser.add_argument(
"--start-epoch",
default=1,
type=int,
metavar="N",
help="manual epoch number (useful on restarts)",
)
parser.add_argument(
"-b",
"--batch-size",
# This should be 256, but that OOMs using the prototype.
default=64,
type=int,
metavar="N",
help="mini-batch size (default: 64), this is the total "
"batch size of all GPUs on the current node when "
"using Data Parallel or Distributed Data Parallel",
)
parser.add_argument(
"-na",
"--n_accumulation_steps",
default=1,
type=int,
metavar="N",
help="number of mini-batches to accumulate into an effective batch",
)
parser.add_argument(
"--lr",
"--learning-rate",
default=0.001,
type=float,
metavar="LR",
help="initial learning rate",
dest="lr",
)
parser.add_argument(
"--momentum", default=0.9, type=float, metavar="M", help="SGD momentum"
)
parser.add_argument(
"--wd",
"--weight-decay",
default=5e-4,
type=float,
metavar="W",
help="SGD weight decay (default: 1e-4)",
dest="weight_decay",
)
parser.add_argument(
"-p",
"--print-freq",
default=10,
type=int,
metavar="N",
help="print frequency (default: 10)",
)
parser.add_argument(
"--resume",
default="",
type=str,
metavar="PATH",
help="path to latest checkpoint (default: none)",
)
parser.add_argument(
"-e",
"--evaluate",
dest="evaluate",
action="store_true",
help="evaluate model on validation set",
)
parser.add_argument(
"--seed", default=None, type=int, help="seed for initializing training. "
)
parser.add_argument(
"--device",
type=str,
default="cuda",
help="GPU ID for this process (default: 'cuda')",
)
parser.add_argument(
"--sigma",
type=float,
default=1.0,
metavar="S",
help="Noise multiplier (default 1.0)",
)
parser.add_argument(
"-c",
"--max-per-sample-grad_norm",
type=float,
default=1.0,
metavar="C",
help="Clip per-sample gradients to this norm (default 1.0)",
)
parser.add_argument(
"--disable-dp",
action="store_true",
default=False,
help="Disable privacy training and just train with vanilla SGD",
)
parser.add_argument(
"--secure-rng",
action="store_true",
default=False,
help="Enable Secure RNG to have trustworthy privacy guarantees. Comes at a performance cost",
)
parser.add_argument(
"--delta",
type=float,
default=1e-5,
metavar="D",
help="Target delta (default: 1e-5)",
)
parser.add_argument(
"--checkpoint-file",
type=str,
default="checkpoint",
help="path to save check points",
)
parser.add_argument(
"--data-root",
type=str,
default="../cifar10",
help="Where CIFAR10 is/will be stored",
)
parser.add_argument(
"--log-dir", type=str, default="", help="Where Tensorboard log will be stored"
)
parser.add_argument(
"--optim",
type=str,
default="Adam",
help="Optimizer to use (Adam, RMSprop, SGD)",
)
args = parser.parse_args()
args.disable_dp = True
if args.disable_dp and args.n_accumulation_steps > 1:
raise ValueError("Virtual steps only works with enabled DP")
# The following few lines, enable stats gathering about the run
# 1. where the stats should be logged
stats.set_global_summary_writer(
tensorboard.SummaryWriter(os.path.join("/tmp/stat", args.log_dir))
)
# 2. enable stats
stats.add(
# stats about gradient norms aggregated for all layers
stats.Stat(stats.StatType.GRAD, "AllLayers", frequency=0.1),
# stats about gradient norms per layer
stats.Stat(stats.StatType.GRAD, "PerLayer", frequency=0.1),
# stats about clipping
stats.Stat(stats.StatType.GRAD, "ClippingStats", frequency=0.1),
# stats on training accuracy
stats.Stat(stats.StatType.TRAIN, "accuracy", frequency=0.01),
# stats on validation accuracy
stats.Stat(stats.StatType.TEST, "accuracy"),
)
# The following lines enable stat gathering for the clipping process
# and set a default of per layer clipping for the Privacy Engine
clipping = {"clip_per_layer": False, "enable_stat": True}
if args.secure_rng:
assert False
try:
import torchcsprng as prng
except ImportError as e:
msg = (
"To use secure RNG, you must install the torchcsprng package! "
"Check out the instructions here: https://github.com/pytorch/csprng#installation"
)
raise ImportError(msg) from e
generator = prng.create_random_device_generator("/dev/urandom")
else:
generator = None
augmentations = [
transforms.RandomCrop(32, padding=4),
transforms.RandomHorizontalFlip(),
]
normalize = [
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
]
train_transform = transforms.Compose(
augmentations + normalize if args.disable_dp else normalize
)
test_transform = transforms.Compose(normalize)
train_dataset = CIFAR10(
root=args.data_root, train=True, download=True, transform=train_transform
)
train_loader = torch.utils.data.DataLoader(
train_dataset,
batch_size=args.batch_size,
shuffle=True,
num_workers=args.workers,
drop_last=True,
generator=generator,
)
test_dataset = CIFAR10(
root=args.data_root, train=False, download=True, transform=test_transform
)
test_loader = torch.utils.data.DataLoader(
test_dataset,
batch_size=args.batch_size,
shuffle=False,
num_workers=args.workers,
)
best_acc1 = 0
device = torch.device(args.device)
model = convert_batchnorm_modules(models.resnet18(num_classes=10))
# model = CIFAR10Model()
model = model.to(device)
if args.optim == "SGD":
optimizer = optim.SGD(
model.parameters(),
lr=args.lr,
momentum=args.momentum,
weight_decay=args.weight_decay,
)
elif args.optim == "RMSprop":
optimizer = optim.RMSprop(model.parameters(), lr=args.lr)
elif args.optim == "Adam":
optimizer = optim.Adam(model.parameters(), lr=args.lr)
else:
raise NotImplementedError("Optimizer not recognized. Please check spelling")
if not args.disable_dp:
privacy_engine = PrivacyEngine(
model,
batch_size=args.batch_size * args.n_accumulation_steps,
sample_size=len(train_dataset),
alphas=[1 + x / 10.0 for x in range(1, 100)] + list(range(12, 64)),
noise_multiplier=args.sigma,
max_grad_norm=args.max_per_sample_grad_norm,
secure_rng=args.secure_rng,
**clipping,
)
privacy_engine.attach(optimizer)
for epoch in range(args.start_epoch, args.epochs + 1):
train(args, model, train_loader, optimizer, epoch, device)
top1_acc = test(args, model, test_loader, device)
# remember best acc@1 and save checkpoint
is_best = top1_acc > best_acc1
best_acc1 = max(top1_acc, best_acc1)
save_checkpoint(
{
"epoch": epoch + 1,
"arch": "ResNet18",
"state_dict": model.state_dict(),
"best_acc1": best_acc1,
"optimizer": optimizer.state_dict(),
},
is_best,
filename=args.checkpoint_file + ".tar",
)
if __name__ == "__main__":
main()

View File

@@ -1,71 +0,0 @@
import torch
import torch.nn as nn
from torch import Tensor
from typing import List, Tuple
import copy
# Utilities to make nn.Module "functional"
# In particular the goal is to be able to provide a function that takes as input
# the parameters and evaluate the nn.Module using fixed inputs.
def _del_nested_attr(obj: nn.Module, names: List[str]) -> None:
"""
Deletes the attribute specified by the given list of names.
For example, to delete the attribute obj.conv.weight,
use _del_nested_attr(obj, ['conv', 'weight'])
"""
if len(names) == 1:
delattr(obj, names[0])
else:
_del_nested_attr(getattr(obj, names[0]), names[1:])
def _set_nested_attr(obj: nn.Module, names: List[str], value: Tensor) -> None:
"""
Set the attribute specified by the given list of names to value.
For example, to set the attribute obj.conv.weight,
use _set_nested_attr(obj, ['conv', 'weight'], value)
"""
if len(names) == 1:
setattr(obj, names[0], value)
else:
_set_nested_attr(getattr(obj, names[0]), names[1:], value)
def extract_weights(mod: nn.Module) -> Tuple[Tuple[Tensor, ...], List[str]]:
"""
This function removes all the Parameters from the model and
returns them as a tuple as well as their original attribute names.
The weights must be re-loaded with `load_weights` before the model
can be used again.
Note that this function modifies the model in place and after this
call, mod.parameters() will be empty.
"""
orig_params = tuple(mod.parameters())
# Remove all the parameters in the model
names = []
for name, p in list(mod.named_parameters()):
_del_nested_attr(mod, name.split("."))
names.append(name)
# Make params regular Tensors instead of nn.Parameter
params = tuple(p.detach().requires_grad_() for p in orig_params)
return params, names
def load_weights(mod: nn.Module, names: List[str], params: Tuple[Tensor, ...], as_params=False) -> None:
"""
Reload a set of weights so that `mod` can be used again to perform a forward pass.
Note that the `params` are regular Tensors (that can have history) and so are left
as Tensors. This means that mod.parameters() will still be empty after this call.
"""
for name, p in zip(names, params):
if as_params:
p = nn.Parameter(p)
_set_nested_attr(mod, name.split("."), p)
def make_functional(model: nn.Module):
weights, descriptors = extract_weights(model)
def fun(weights, data):
mutable_model = copy.deepcopy(model)
load_weights(mutable_model, descriptors, weights)
return mutable_model(*data)
return weights, fun, descriptors
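
A minimal usage sketch of the removed helper above; the module and input shapes are illustrative assumptions, and it relies on the `make_functional` / `load_weights` definitions in this file:

import torch
import torch.nn as nn

model = nn.Linear(3, 2)
weights, func_model, descriptors = make_functional(model)

# forward pass driven by an explicit tuple of weights (`data` is a tuple of inputs)
out = func_model(weights, (torch.randn(4, 3),))

# restore the weights as Parameters so regular optimizers can see them again
load_weights(model, descriptors, weights, as_params=True)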

View File

@@ -2,7 +2,7 @@ import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from functorch import make_functional, grad_with_value, vmap
from functorch import make_functional, grad_and_value, vmap
# Adapted from http://willwhitney.com/parallel-training-jax.html
# GOAL: Demonstrate that it is possible to use eager-mode vmap
@@ -58,7 +58,7 @@ def train_step_fn(weights, batch, targets, lr=0.2):
loss = loss_fn(output, targets)
return loss
grad_weights, loss = grad_with_value(compute_loss)(weights, batch, targets)
grad_weights, loss = grad_and_value(compute_loss)(weights, batch, targets)
# NB: PyTorch is missing a "functional optimizer API" (possibly coming soon)
# so we are going to re-implement SGD here.
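
As a self-contained sketch of that pattern — `grad_and_value` for the gradient plus a hand-rolled SGD update; the toy regression model, shapes, and learning rate are assumptions, only the functorch call mirrors the diff:

import torch
from functorch import grad_and_value

def compute_loss(weights, batch, targets):
    preds = batch @ weights                 # toy linear model
    return ((preds - targets) ** 2).mean()  # MSE loss

weights = torch.randn(3, 1)
batch, targets = torch.randn(8, 3), torch.randn(8, 1)

grad_weights, loss = grad_and_value(compute_loss)(weights, batch, targets)
lr = 0.2
weights = weights - lr * grad_weights  # manual SGD step, no functional optimizer API yet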

View File

@@ -141,7 +141,7 @@ def grad_and_value(f, argnums=0, has_aux=False):
aux = _undo_create_differentiable(aux, level)
_grad_decrement_nesting()
if has_aux:
return grad_input, output, aux
return grad_input, (output, aux)
return grad_input, output
return wrapper
@@ -150,7 +150,9 @@ def grad(f, argnums=0, has_aux=False):
def wrapper(*args):
results = grad_and_value(f, argnums, has_aux=has_aux)(*args)
if has_aux:
return results[0], results[2]
return results[0]
grad, (value, aux) = results
return grad, aux
grad, value = results
return grad
return wrapper
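
The practical upshot of this change, as a minimal sketch (the toy function and inputs are assumptions, and it assumes both transforms are exported from the functorch namespace): `grad_and_value(f, has_aux=True)` now returns `(grad, (value, aux))`, while `grad(f, has_aux=True)` returns `(grad, aux)`:

import torch
from functorch import grad, grad_and_value

def f(x):
    loss = (x ** 2).sum()
    aux = x.detach() + 1  # auxiliary value, excluded from differentiation
    return loss, aux

x = torch.tensor([1.0, 2.0])

g, (value, aux) = grad_and_value(f, has_aux=True)(x)  # grad, (loss, aux)
g2, aux2 = grad(f, has_aux=True)(x)                   # grad, aux only
assert torch.allclose(g, 2 * x) and torch.allclose(g, g2)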