Mirror of https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
[functorch] Fix up the cifar10_transforms example
@@ -25,9 +25,8 @@ from torchvision.datasets import CIFAR10
 from tqdm import tqdm

 from make_functional import make_functional, load_weights
-from torch.eager_transforms import vmap, grad_with_value
 from functools import partial
 # from resnet import resnet18
+from functorch import vmap, grad_and_value

 def save_checkpoint(state, is_best, filename="checkpoint.tar"):
     torch.save(state, filename)
@@ -104,7 +103,7 @@ def train(args, model, train_loader, optimizer, epoch, device):
         # We want to extract some intermediate
         # values from the computation (i.e. the loss and output).
         #
-        # To extract the loss, we use the `grad_with_value` API, that returns the
+        # To extract the loss, we use the `grad_and_value` API, that returns the
         # gradient of the weights w.r.t. the loss and the loss.
         #
         # To extract the output, we use the `has_aux=True` flag.
@@ -112,13 +111,9 @@ def train(args, model, train_loader, optimizer, epoch, device):
         # where the first is to be differentiated and the second "auxiliary value"
         # is not to be differentiated. `f'` returns the gradient w.r.t. the loss,
         # the loss, and the auxiliary value.
-        grads_loss_output = grad_with_value(compute_loss_and_output, has_aux=True)
-        def packed(weights, images, target):
-            grads, loss, output = grads_loss_output(weights, images, target)
-            result = tuple([*grads, loss, output])
-            return result
-        result = vmap(partial(packed, weights))(images, target)
-        sample_grads, sample_loss, output, = result[:-2], result[-2], result[-1]
+        grads_loss_output = grad_and_value(compute_loss_and_output, has_aux=True)
+        sample_grads, (sample_loss, output) = \
+            vmap(grads_loss_output, (None, 0, 0))(weights, images, target)
         loss = sample_loss.mean()

         # `load_weights` is the inverse operation of make_functional. We put
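For readers following along, here is a minimal, self-contained sketch of the per-sample-gradient pattern the new code uses. It is illustrative only: the model, data, and shapes are made up, and it uses functorch's own make_functional (which returns a functional model and its parameters), whereas the example above uses a local make_functional helper with a different return signature. The sections below are older prototype copies of this example that the commit deletes.

import torch
import torch.nn as nn
from functorch import make_functional, vmap, grad_and_value

model = nn.Linear(10, 2)           # toy stand-in for the example's ResNet
criterion = nn.CrossEntropyLoss()
func_model, params = make_functional(model)

def compute_loss_and_output(params, image, target):
    # Add a batch dimension of 1 so the model sees a normal batched input.
    output = func_model(params, image.unsqueeze(0))
    loss = criterion(output, target.unsqueeze(0))
    # With has_aux=True, the first return value is differentiated and the
    # second is passed through untouched.
    return loss, output.squeeze(0)

grads_loss_output = grad_and_value(compute_loss_and_output, has_aux=True)

images = torch.randn(8, 10)
targets = torch.randint(0, 2, (8,))
# in_dims=(None, 0, 0): share params across the batch, map over images/targets.
sample_grads, (sample_loss, output) = vmap(
    grads_loss_output, in_dims=(None, 0, 0))(params, images, targets)

# One gradient tensor per parameter, each with a leading batch dimension of 8;
# sample_loss and output are per-sample as well.
print([g.shape for g in sample_grads], sample_loss.shape, output.shape)
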
@@ -1,438 +0,0 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

"""
Runs CIFAR10 training with differential privacy.
"""

import argparse
import os
import shutil

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data
import torch.utils.data.distributed
import torch.utils.tensorboard as tensorboard
import torchvision.models as models
import torchvision.transforms as transforms
from opacus import PrivacyEngine
from opacus.utils import stats
from opacus.utils.module_modification import convert_batchnorm_modules
from torchvision.datasets import CIFAR10
from tqdm import tqdm

from functools import partial
from make_functional import make_functional, load_weights

# NB: The following might not exist depending on what you're using
from torch import vmap
from functional_utils import grad, grad_with_value


def save_checkpoint(state, is_best, filename="checkpoint.tar"):
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, "model_best.pth.tar")


def accuracy(preds, labels):
    return (preds == labels).mean()


def train(args, model, train_loader, optimizer, epoch, device):
    model.train()
    criterion = nn.CrossEntropyLoss()

    losses = []
    top1_acc = []

    for i, (images, target) in enumerate(tqdm(train_loader)):
        images = images.to(device)
        target = target.to(device)

        # Step 1: compute per-sample-grads

        # TODO(rzou): does the group norm work correctly?
        weights, func_model, descriptors = make_functional(model)
        [weight.requires_grad_(False) for weight in weights]

        def compute_loss_and_output(weights, image, target):
            images = image.unsqueeze(0)
            target = target.unsqueeze(0)
            output = func_model(weights, (images,))
            loss = criterion(output, target)
            return loss, output.squeeze(0)

        grads_loss_output = grad_with_value(compute_loss_and_output, has_aux=True)
        grads, sample_loss, output = vmap(partial(grads_loss_output, weights))(images, target)
        loss = sample_loss.mean(0)

        # Step 2: Clip the per-sample-grads and sum them to form grads

        # TODO(rzou): Right now we just sum the grads. Instead we need to clip them.
        for sample_grad, weight in zip(grads, weights):
            weight_grad = sample_grad.sum(0)
            weight.grad = weight_grad

        # `load_weights` is the inverse operation of make_functional. We put
        # things back into a model so that we can directly apply optimizers.
        # TODO(rzou): this might not be necessary, optimizers just take
        # the params straight up.
        [weight.requires_grad_(True) for weight in weights]
        load_weights(model, descriptors, weights, as_params=True)

        preds = np.argmax(output.detach().cpu().numpy(), axis=1)
        labels = target.detach().cpu().numpy()

        # measure accuracy and record loss
        acc1 = accuracy(preds, labels)

        losses.append(loss.item())
        top1_acc.append(acc1)
        stats.update(stats.StatType.TRAIN, acc1=acc1)

        # make sure we take a step after processing the last mini-batch in the
        # epoch to ensure we start the next epoch with a clean state
        if ((i + 1) % args.n_accumulation_steps == 0) or ((i + 1) == len(train_loader)):
            optimizer.step()
            optimizer.zero_grad()
        else:
            optimizer.virtual_step()

        if i % args.print_freq == 0:
            if not args.disable_dp:
                epsilon, best_alpha = optimizer.privacy_engine.get_privacy_spent(
                    args.delta
                )
                print(
                    f"\tTrain Epoch: {epoch} \t"
                    f"Loss: {np.mean(losses):.6f} "
                    f"Acc@1: {np.mean(top1_acc):.6f} "
                    f"(ε = {epsilon:.2f}, δ = {args.delta}) for α = {best_alpha}"
                )
            else:
                print(
                    f"\tTrain Epoch: {epoch} \t"
                    f"Loss: {np.mean(losses):.6f} "
                    f"Acc@1: {np.mean(top1_acc):.6f} "
                )
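As a side note on the TODO above (this prototype sums the per-sample gradients without clipping them), a hypothetical helper along the following lines would bound each sample's contribution before the sum; a later revision of this example implements the same idea as clip_and_accumulate_and_add_noise, reproduced further down.

def clip_then_sum(sample_grads, max_norm=1.0):
    # sample_grads: tuple of tensors, each with a leading per-sample dimension.
    batch_size = sample_grads[0].shape[0]
    flat = [g.reshape(batch_size, -1) for g in sample_grads]
    # Per-sample L2 norm taken over all parameters jointly.
    per_sample_norm = torch.cat(flat, dim=1).norm(2, dim=1)
    clip_factor = (max_norm / (per_sample_norm + 1e-6)).clamp(max=1.0)
    # Scale each sample's gradient by its clip factor and sum over the batch.
    return tuple(torch.einsum("i,i...", clip_factor, g) for g in sample_grads)
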
def test(args, model, test_loader, device):
|
||||
model.eval()
|
||||
criterion = nn.CrossEntropyLoss()
|
||||
losses = []
|
||||
top1_acc = []
|
||||
|
||||
with torch.no_grad():
|
||||
for images, target in tqdm(test_loader):
|
||||
images = images.to(device)
|
||||
target = target.to(device)
|
||||
|
||||
output = model(images)
|
||||
loss = criterion(output, target)
|
||||
preds = np.argmax(output.detach().cpu().numpy(), axis=1)
|
||||
labels = target.detach().cpu().numpy()
|
||||
acc1 = accuracy(preds, labels)
|
||||
|
||||
losses.append(loss.item())
|
||||
top1_acc.append(acc1)
|
||||
|
||||
top1_avg = np.mean(top1_acc)
|
||||
stats.update(stats.StatType.TEST, acc1=top1_avg)
|
||||
|
||||
print(f"\tTest set:" f"Loss: {np.mean(losses):.6f} " f"Acc@1: {top1_avg :.6f} ")
|
||||
return np.mean(top1_acc)
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="PyTorch CIFAR10 DP Training")
|
||||
parser.add_argument(
|
||||
"-j",
|
||||
"--workers",
|
||||
default=2,
|
||||
type=int,
|
||||
metavar="N",
|
||||
help="number of data loading workers (default: 2)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--epochs",
|
||||
default=90,
|
||||
type=int,
|
||||
metavar="N",
|
||||
help="number of total epochs to run",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--start-epoch",
|
||||
default=1,
|
||||
type=int,
|
||||
metavar="N",
|
||||
help="manual epoch number (useful on restarts)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-b",
|
||||
"--batch-size",
|
||||
default=256,
|
||||
type=int,
|
||||
metavar="N",
|
||||
help="mini-batch size (default: 256), this is the total "
|
||||
"batch size of all GPUs on the current node when "
|
||||
"using Data Parallel or Distributed Data Parallel",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-na",
|
||||
"--n_accumulation_steps",
|
||||
default=1,
|
||||
type=int,
|
||||
metavar="N",
|
||||
help="number of mini-batches to accumulate into an effective batch",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--lr",
|
||||
"--learning-rate",
|
||||
default=0.001,
|
||||
type=float,
|
||||
metavar="LR",
|
||||
help="initial learning rate",
|
||||
dest="lr",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--momentum", default=0.9, type=float, metavar="M", help="SGD momentum"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--wd",
|
||||
"--weight-decay",
|
||||
default=5e-4,
|
||||
type=float,
|
||||
metavar="W",
|
||||
help="SGD weight decay (default: 1e-4)",
|
||||
dest="weight_decay",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-p",
|
||||
"--print-freq",
|
||||
default=10,
|
||||
type=int,
|
||||
metavar="N",
|
||||
help="print frequency (default: 10)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--resume",
|
||||
default="",
|
||||
type=str,
|
||||
metavar="PATH",
|
||||
help="path to latest checkpoint (default: none)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-e",
|
||||
"--evaluate",
|
||||
dest="evaluate",
|
||||
action="store_true",
|
||||
help="evaluate model on validation set",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--seed", default=None, type=int, help="seed for initializing training. "
|
||||
)
|
||||
parser.add_argument(
|
||||
"--device",
|
||||
type=str,
|
||||
default="cuda",
|
||||
help="GPU ID for this process (default: 'cuda')",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--sigma",
|
||||
type=float,
|
||||
default=1.0,
|
||||
metavar="S",
|
||||
help="Noise multiplier (default 1.0)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-c",
|
||||
"--max-per-sample-grad_norm",
|
||||
type=float,
|
||||
default=1.0,
|
||||
metavar="C",
|
||||
help="Clip per-sample gradients to this norm (default 1.0)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--disable-dp",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Disable privacy training and just train with vanilla SGD",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--secure-rng",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Enable Secure RNG to have trustworthy privacy guarantees. Comes at a performance cost",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--delta",
|
||||
type=float,
|
||||
default=1e-5,
|
||||
metavar="D",
|
||||
help="Target delta (default: 1e-5)",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--checkpoint-file",
|
||||
type=str,
|
||||
default="checkpoint",
|
||||
help="path to save check points",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--data-root",
|
||||
type=str,
|
||||
default="../cifar10",
|
||||
help="Where CIFAR10 is/will be stored",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--log-dir", type=str, default="", help="Where Tensorboard log will be stored"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--optim",
|
||||
type=str,
|
||||
default="Adam",
|
||||
help="Optimizer to use (Adam, RMSprop, SGD)",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
args.disable_dp = True
|
||||
|
||||
if args.disable_dp and args.n_accumulation_steps > 1:
|
||||
raise ValueError("Virtual steps only works with enabled DP")
|
||||
|
||||
# The following few lines, enable stats gathering about the run
|
||||
# 1. where the stats should be logged
|
||||
stats.set_global_summary_writer(
|
||||
tensorboard.SummaryWriter(os.path.join("/tmp/stat", args.log_dir))
|
||||
)
|
||||
# 2. enable stats
|
||||
stats.add(
|
||||
# stats about gradient norms aggregated for all layers
|
||||
stats.Stat(stats.StatType.GRAD, "AllLayers", frequency=0.1),
|
||||
# stats about gradient norms per layer
|
||||
stats.Stat(stats.StatType.GRAD, "PerLayer", frequency=0.1),
|
||||
# stats about clipping
|
||||
stats.Stat(stats.StatType.GRAD, "ClippingStats", frequency=0.1),
|
||||
# stats on training accuracy
|
||||
stats.Stat(stats.StatType.TRAIN, "accuracy", frequency=0.01),
|
||||
# stats on validation accuracy
|
||||
stats.Stat(stats.StatType.TEST, "accuracy"),
|
||||
)
|
||||
|
||||
# The following lines enable stat gathering for the clipping process
|
||||
# and set a default of per layer clipping for the Privacy Engine
|
||||
clipping = {"clip_per_layer": False, "enable_stat": True}
|
||||
|
||||
if args.secure_rng:
|
||||
assert False
|
||||
try:
|
||||
import torchcsprng as prng
|
||||
except ImportError as e:
|
||||
msg = (
|
||||
"To use secure RNG, you must install the torchcsprng package! "
|
||||
"Check out the instructions here: https://github.com/pytorch/csprng#installation"
|
||||
)
|
||||
raise ImportError(msg) from e
|
||||
|
||||
generator = prng.create_random_device_generator("/dev/urandom")
|
||||
|
||||
else:
|
||||
generator = None
|
||||
|
||||
augmentations = [
|
||||
transforms.RandomCrop(32, padding=4),
|
||||
transforms.RandomHorizontalFlip(),
|
||||
]
|
||||
normalize = [
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
|
||||
]
|
||||
train_transform = transforms.Compose(
|
||||
augmentations + normalize if args.disable_dp else normalize
|
||||
)
|
||||
|
||||
test_transform = transforms.Compose(normalize)
|
||||
|
||||
train_dataset = CIFAR10(
|
||||
root=args.data_root, train=True, download=True, transform=train_transform
|
||||
)
|
||||
|
||||
train_loader = torch.utils.data.DataLoader(
|
||||
train_dataset,
|
||||
batch_size=args.batch_size,
|
||||
shuffle=True,
|
||||
num_workers=args.workers,
|
||||
drop_last=True,
|
||||
generator=generator,
|
||||
)
|
||||
|
||||
test_dataset = CIFAR10(
|
||||
root=args.data_root, train=False, download=True, transform=test_transform
|
||||
)
|
||||
test_loader = torch.utils.data.DataLoader(
|
||||
test_dataset,
|
||||
batch_size=args.batch_size,
|
||||
shuffle=False,
|
||||
num_workers=args.workers,
|
||||
)
|
||||
|
||||
best_acc1 = 0
|
||||
device = torch.device(args.device)
|
||||
model = convert_batchnorm_modules(models.resnet18(num_classes=10))
|
||||
model = model.to(device)
|
||||
|
||||
if args.optim == "SGD":
|
||||
optimizer = optim.SGD(
|
||||
model.parameters(),
|
||||
lr=args.lr,
|
||||
momentum=args.momentum,
|
||||
weight_decay=args.weight_decay,
|
||||
)
|
||||
elif args.optim == "RMSprop":
|
||||
optimizer = optim.RMSprop(model.parameters(), lr=args.lr)
|
||||
elif args.optim == "Adam":
|
||||
optimizer = optim.Adam(model.parameters(), lr=args.lr)
|
||||
else:
|
||||
raise NotImplementedError("Optimizer not recognized. Please check spelling")
|
||||
|
||||
if not args.disable_dp:
|
||||
privacy_engine = PrivacyEngine(
|
||||
model,
|
||||
batch_size=args.batch_size * args.n_accumulation_steps,
|
||||
sample_size=len(train_dataset),
|
||||
alphas=[1 + x / 10.0 for x in range(1, 100)] + list(range(12, 64)),
|
||||
noise_multiplier=args.sigma,
|
||||
max_grad_norm=args.max_per_sample_grad_norm,
|
||||
secure_rng=args.secure_rng,
|
||||
**clipping,
|
||||
)
|
||||
privacy_engine.attach(optimizer)
|
||||
|
||||
for epoch in range(args.start_epoch, args.epochs + 1):
|
||||
train(args, model, train_loader, optimizer, epoch, device)
|
||||
top1_acc = test(args, model, test_loader, device)
|
||||
|
||||
# remember best acc@1 and save checkpoint
|
||||
is_best = top1_acc > best_acc1
|
||||
best_acc1 = max(top1_acc, best_acc1)
|
||||
|
||||
save_checkpoint(
|
||||
{
|
||||
"epoch": epoch + 1,
|
||||
"arch": "ResNet18",
|
||||
"state_dict": model.state_dict(),
|
||||
"best_acc1": best_acc1,
|
||||
"optimizer": optimizer.state_dict(),
|
||||
},
|
||||
is_best,
|
||||
filename=args.checkpoint_file + ".tar",
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@@ -1,491 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
||||
|
||||
"""
|
||||
Runs CIFAR10 training with differential privacy.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import shutil
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
import torch.utils.data
|
||||
import torch.utils.data.distributed
|
||||
import torch.utils.tensorboard as tensorboard
|
||||
import torchvision.models as models
|
||||
import torchvision.transforms as transforms
|
||||
from opacus import PrivacyEngine
|
||||
from opacus.utils import stats
|
||||
from opacus.utils.module_modification import convert_batchnorm_modules
|
||||
from torchvision.datasets import CIFAR10
|
||||
from tqdm import tqdm
|
||||
|
||||
from torch import vmap
|
||||
from make_functional import make_functional, load_weights
|
||||
from functional_utils import grad, grad_with_value
|
||||
from functools import partial
|
||||
# from resnet import resnet18
|
||||
|
||||
class CudaMemoryLeakCheck():
    def __init__(self, name):
        self.name = name

        # initialize context & RNG to prevent false positive detections
        # when the test is the first to initialize those
        from torch.testing._internal.common_cuda import initialize_cuda_context_rng
        initialize_cuda_context_rng()

    @staticmethod
    def get_cuda_memory_usage():
        # we don't need CUDA synchronize because the statistics are not tracked at
        # actual freeing, but when the block is marked as free.
        num_devices = torch.cuda.device_count()
        import gc
        gc.collect()
        return tuple(torch.cuda.memory_allocated(i) for i in range(num_devices))

    def __enter__(self):
        self.befores = self.get_cuda_memory_usage()

    def __exit__(self, exec_type, exec_value, traceback):
        # Don't check for leaks if an exception was thrown
        if exec_type is not None:
            return

        afters = self.get_cuda_memory_usage()

        for i, (before, after) in enumerate(zip(self.befores, afters)):
            if after - before == 0:
                continue
            raise RuntimeError(f'{self.name} leaked {after-before} bytes')
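A short usage sketch of the context manager above (illustrative only; it assumes a CUDA device is available and uses made-up tensor sizes):

if torch.cuda.is_available():
    with CudaMemoryLeakCheck("toy workload"):
        x = torch.randn(1024, device="cuda")
        y = (x * 2).sum()
        # Everything allocated inside the block is released before it exits,
        # so memory_allocated() returns to its previous value in __exit__.
        del x, y
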


def save_checkpoint(state, is_best, filename="checkpoint.tar"):
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, "model_best.pth.tar")


def accuracy(preds, labels):
    return (preds == labels).mean()


def compute_norms(sample_grads):
    batch_size = sample_grads[0].shape[0]
    norms = [sample_grad.view(batch_size, -1).norm(2, dim=-1) for sample_grad in sample_grads]
    norms = torch.stack(norms, dim=0).norm(2, dim=0)
    return norms


def clip_and_accumulate_and_add_noise(sample_grads, max_per_sample_grad_norm=1.0, noise_multiplier=1.0):
    # step 0: compute the norms
    sample_norms = compute_norms(sample_grads)

    # step 1: compute clipping factors
    clip_factor = max_per_sample_grad_norm / (sample_norms + 1e-6)
    clip_factor = clip_factor.clamp(max=1.0)

    # step 2: clip
    grads = tuple(torch.einsum('i,i...', clip_factor, sample_grad)
                  for sample_grad in sample_grads)

    # step 3: add gaussian noise
    stddev = max_per_sample_grad_norm * noise_multiplier
    noises = tuple(torch.normal(0, stddev, grad_param.shape, device=grad_param.device)
                   for grad_param in grads)
    grads = tuple(noise + grad_param for noise, grad_param in zip(noises, grads))

    return grads
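The einsum 'i,i...' above scales each per-sample gradient by its clip factor and sums over the per-sample dimension in a single step. A toy invocation, with shapes made up purely for illustration:

sample_grads = (torch.randn(4, 3, 3), torch.randn(4, 3))   # 4 samples, 2 parameters
noisy_grads = clip_and_accumulate_and_add_noise(
    sample_grads, max_per_sample_grad_norm=1.0, noise_multiplier=1.0)
# The per-sample dimension is summed away: shapes are (3, 3) and (3,).
print([g.shape for g in noisy_grads])
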
def train(args, model, train_loader, optimizer, epoch, device):
|
||||
use_prototype = False
|
||||
model.train()
|
||||
criterion = nn.CrossEntropyLoss()
|
||||
|
||||
losses = []
|
||||
top1_acc = []
|
||||
|
||||
for i, (images, target) in enumerate(tqdm(train_loader)):
|
||||
|
||||
images = images.to(device)
|
||||
target = target.to(device)
|
||||
|
||||
# Step 1: compute per-sample-grads
|
||||
weights, func_model, descriptors = make_functional(model)
|
||||
|
||||
def compute_loss_and_output(weights, image, target):
|
||||
images = image.unsqueeze(0)
|
||||
targets = target.unsqueeze(0)
|
||||
output = func_model(weights, (images,))
|
||||
loss = criterion(output, targets)
|
||||
return loss, output.squeeze(0)
|
||||
|
||||
# grad_with_value(f) returns a function that returns (1) the grad and
|
||||
# (2) the output. `has_aux=True` means that `f` returns a tuple of two values,
|
||||
# where the first is to be differentiated and the second is not to be
|
||||
# differentiated and further adds a 3rd output.
|
||||
#
|
||||
# We need to use `grad_with_value(..., has_aux=True)` because we do
|
||||
# some analyses on the returned loss and output.
|
||||
grads_loss_output = grad_with_value(compute_loss_and_output, has_aux=True)
|
||||
sample_grads, sample_loss, output = vmap(partial(grads_loss_output, weights))(images, target)
|
||||
loss = sample_loss.mean()
|
||||
|
||||
# Step 2: Clip the per-sample-grads, sum them to form grads, and add noise
|
||||
grads = clip_and_accumulate_and_add_noise(
|
||||
sample_grads, args.max_per_sample_grad_norm, args.sigma)
|
||||
|
||||
# `load_weights` is the inverse operation of make_functional. We put
|
||||
# things back into a model so that we can directly apply optimizers.
|
||||
# TODO(rzou): this might not be necessary, optimizers just take
|
||||
# the params straight up.
|
||||
load_weights(model, descriptors, weights)
|
||||
|
||||
for weight_grad, weight in zip(grads, model.parameters()):
|
||||
weight.grad = weight_grad.detach()
|
||||
|
||||
preds = np.argmax(output.detach().cpu().numpy(), axis=1)
|
||||
labels = target.detach().cpu().numpy()
|
||||
losses.append(loss.item())
|
||||
|
||||
# measure accuracy and record loss
|
||||
acc1 = accuracy(preds, labels)
|
||||
|
||||
top1_acc.append(acc1)
|
||||
stats.update(stats.StatType.TRAIN, acc1=acc1)
|
||||
|
||||
# make sure we take a step after processing the last mini-batch in the
|
||||
# epoch to ensure we start the next epoch with a clean state
|
||||
if ((i + 1) % args.n_accumulation_steps == 0) or ((i + 1) == len(train_loader)):
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
else:
|
||||
optimizer.virtual_step()
|
||||
|
||||
if i % args.print_freq == 0:
|
||||
print(
|
||||
f"\tTrain Epoch: {epoch} \t"
|
||||
f"Loss: {np.mean(losses):.6f} "
|
||||
f"Acc@1: {np.mean(top1_acc):.6f} "
|
||||
)
|
||||
|
||||
def test(args, model, test_loader, device):
|
||||
model.eval()
|
||||
criterion = nn.CrossEntropyLoss()
|
||||
losses = []
|
||||
top1_acc = []
|
||||
|
||||
with torch.no_grad():
|
||||
for images, target in tqdm(test_loader):
|
||||
images = images.to(device)
|
||||
target = target.to(device)
|
||||
|
||||
output = model(images)
|
||||
loss = criterion(output, target)
|
||||
preds = np.argmax(output.detach().cpu().numpy(), axis=1)
|
||||
labels = target.detach().cpu().numpy()
|
||||
acc1 = accuracy(preds, labels)
|
||||
|
||||
losses.append(loss.item())
|
||||
top1_acc.append(acc1)
|
||||
|
||||
top1_avg = np.mean(top1_acc)
|
||||
stats.update(stats.StatType.TEST, acc1=top1_avg)
|
||||
|
||||
print(f"\tTest set:" f"Loss: {np.mean(losses):.6f} " f"Acc@1: {top1_avg :.6f} ")
|
||||
return np.mean(top1_acc)
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="PyTorch CIFAR10 DP Training")
|
||||
parser.add_argument(
|
||||
"-j",
|
||||
"--workers",
|
||||
default=2,
|
||||
type=int,
|
||||
metavar="N",
|
||||
help="number of data loading workers (default: 2)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--epochs",
|
||||
default=90,
|
||||
type=int,
|
||||
metavar="N",
|
||||
help="number of total epochs to run",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--start-epoch",
|
||||
default=1,
|
||||
type=int,
|
||||
metavar="N",
|
||||
help="manual epoch number (useful on restarts)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-b",
|
||||
"--batch-size",
|
||||
# This should be 256, but that OOMs using the prototype.
|
||||
default=64,
|
||||
type=int,
|
||||
metavar="N",
|
||||
help="mini-batch size (default: 64), this is the total "
|
||||
"batch size of all GPUs on the current node when "
|
||||
"using Data Parallel or Distributed Data Parallel",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-na",
|
||||
"--n_accumulation_steps",
|
||||
default=1,
|
||||
type=int,
|
||||
metavar="N",
|
||||
help="number of mini-batches to accumulate into an effective batch",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--lr",
|
||||
"--learning-rate",
|
||||
default=0.001,
|
||||
type=float,
|
||||
metavar="LR",
|
||||
help="initial learning rate",
|
||||
dest="lr",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--momentum", default=0.9, type=float, metavar="M", help="SGD momentum"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--wd",
|
||||
"--weight-decay",
|
||||
default=5e-4,
|
||||
type=float,
|
||||
metavar="W",
|
||||
help="SGD weight decay (default: 1e-4)",
|
||||
dest="weight_decay",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-p",
|
||||
"--print-freq",
|
||||
default=10,
|
||||
type=int,
|
||||
metavar="N",
|
||||
help="print frequency (default: 10)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--resume",
|
||||
default="",
|
||||
type=str,
|
||||
metavar="PATH",
|
||||
help="path to latest checkpoint (default: none)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-e",
|
||||
"--evaluate",
|
||||
dest="evaluate",
|
||||
action="store_true",
|
||||
help="evaluate model on validation set",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--seed", default=None, type=int, help="seed for initializing training. "
|
||||
)
|
||||
parser.add_argument(
|
||||
"--device",
|
||||
type=str,
|
||||
default="cuda",
|
||||
help="GPU ID for this process (default: 'cuda')",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--sigma",
|
||||
type=float,
|
||||
default=1.0,
|
||||
metavar="S",
|
||||
help="Noise multiplier (default 1.0)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-c",
|
||||
"--max-per-sample-grad_norm",
|
||||
type=float,
|
||||
default=1.0,
|
||||
metavar="C",
|
||||
help="Clip per-sample gradients to this norm (default 1.0)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--disable-dp",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Disable privacy training and just train with vanilla SGD",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--secure-rng",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Enable Secure RNG to have trustworthy privacy guarantees. Comes at a performance cost",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--delta",
|
||||
type=float,
|
||||
default=1e-5,
|
||||
metavar="D",
|
||||
help="Target delta (default: 1e-5)",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--checkpoint-file",
|
||||
type=str,
|
||||
default="checkpoint",
|
||||
help="path to save check points",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--data-root",
|
||||
type=str,
|
||||
default="../cifar10",
|
||||
help="Where CIFAR10 is/will be stored",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--log-dir", type=str, default="", help="Where Tensorboard log will be stored"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--optim",
|
||||
type=str,
|
||||
default="Adam",
|
||||
help="Optimizer to use (Adam, RMSprop, SGD)",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
args.disable_dp = True
|
||||
|
||||
if args.disable_dp and args.n_accumulation_steps > 1:
|
||||
raise ValueError("Virtual steps only works with enabled DP")
|
||||
|
||||
# The following few lines, enable stats gathering about the run
|
||||
# 1. where the stats should be logged
|
||||
stats.set_global_summary_writer(
|
||||
tensorboard.SummaryWriter(os.path.join("/tmp/stat", args.log_dir))
|
||||
)
|
||||
# 2. enable stats
|
||||
stats.add(
|
||||
# stats about gradient norms aggregated for all layers
|
||||
stats.Stat(stats.StatType.GRAD, "AllLayers", frequency=0.1),
|
||||
# stats about gradient norms per layer
|
||||
stats.Stat(stats.StatType.GRAD, "PerLayer", frequency=0.1),
|
||||
# stats about clipping
|
||||
stats.Stat(stats.StatType.GRAD, "ClippingStats", frequency=0.1),
|
||||
# stats on training accuracy
|
||||
stats.Stat(stats.StatType.TRAIN, "accuracy", frequency=0.01),
|
||||
# stats on validation accuracy
|
||||
stats.Stat(stats.StatType.TEST, "accuracy"),
|
||||
)
|
||||
|
||||
# The following lines enable stat gathering for the clipping process
|
||||
# and set a default of per layer clipping for the Privacy Engine
|
||||
clipping = {"clip_per_layer": False, "enable_stat": True}
|
||||
|
||||
if args.secure_rng:
|
||||
assert False
|
||||
try:
|
||||
import torchcsprng as prng
|
||||
except ImportError as e:
|
||||
msg = (
|
||||
"To use secure RNG, you must install the torchcsprng package! "
|
||||
"Check out the instructions here: https://github.com/pytorch/csprng#installation"
|
||||
)
|
||||
raise ImportError(msg) from e
|
||||
|
||||
generator = prng.create_random_device_generator("/dev/urandom")
|
||||
|
||||
else:
|
||||
generator = None
|
||||
|
||||
augmentations = [
|
||||
transforms.RandomCrop(32, padding=4),
|
||||
transforms.RandomHorizontalFlip(),
|
||||
]
|
||||
normalize = [
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
|
||||
]
|
||||
train_transform = transforms.Compose(
|
||||
augmentations + normalize if args.disable_dp else normalize
|
||||
)
|
||||
|
||||
test_transform = transforms.Compose(normalize)
|
||||
|
||||
train_dataset = CIFAR10(
|
||||
root=args.data_root, train=True, download=True, transform=train_transform
|
||||
)
|
||||
|
||||
train_loader = torch.utils.data.DataLoader(
|
||||
train_dataset,
|
||||
batch_size=args.batch_size,
|
||||
shuffle=True,
|
||||
num_workers=args.workers,
|
||||
drop_last=True,
|
||||
generator=generator,
|
||||
)
|
||||
|
||||
test_dataset = CIFAR10(
|
||||
root=args.data_root, train=False, download=True, transform=test_transform
|
||||
)
|
||||
test_loader = torch.utils.data.DataLoader(
|
||||
test_dataset,
|
||||
batch_size=args.batch_size,
|
||||
shuffle=False,
|
||||
num_workers=args.workers,
|
||||
)
|
||||
|
||||
best_acc1 = 0
|
||||
device = torch.device(args.device)
|
||||
model = convert_batchnorm_modules(models.resnet18(num_classes=10))
|
||||
# model = CIFAR10Model()
|
||||
model = model.to(device)
|
||||
|
||||
if args.optim == "SGD":
|
||||
optimizer = optim.SGD(
|
||||
model.parameters(),
|
||||
lr=args.lr,
|
||||
momentum=args.momentum,
|
||||
weight_decay=args.weight_decay,
|
||||
)
|
||||
elif args.optim == "RMSprop":
|
||||
optimizer = optim.RMSprop(model.parameters(), lr=args.lr)
|
||||
elif args.optim == "Adam":
|
||||
optimizer = optim.Adam(model.parameters(), lr=args.lr)
|
||||
else:
|
||||
raise NotImplementedError("Optimizer not recognized. Please check spelling")
|
||||
|
||||
if not args.disable_dp:
|
||||
privacy_engine = PrivacyEngine(
|
||||
model,
|
||||
batch_size=args.batch_size * args.n_accumulation_steps,
|
||||
sample_size=len(train_dataset),
|
||||
alphas=[1 + x / 10.0 for x in range(1, 100)] + list(range(12, 64)),
|
||||
noise_multiplier=args.sigma,
|
||||
max_grad_norm=args.max_per_sample_grad_norm,
|
||||
secure_rng=args.secure_rng,
|
||||
**clipping,
|
||||
)
|
||||
privacy_engine.attach(optimizer)
|
||||
|
||||
for epoch in range(args.start_epoch, args.epochs + 1):
|
||||
train(args, model, train_loader, optimizer, epoch, device)
|
||||
top1_acc = test(args, model, test_loader, device)
|
||||
|
||||
# remember best acc@1 and save checkpoint
|
||||
is_best = top1_acc > best_acc1
|
||||
best_acc1 = max(top1_acc, best_acc1)
|
||||
|
||||
save_checkpoint(
|
||||
{
|
||||
"epoch": epoch + 1,
|
||||
"arch": "ResNet18",
|
||||
"state_dict": model.state_dict(),
|
||||
"best_acc1": best_acc1,
|
||||
"optimizer": optimizer.state_dict(),
|
||||
},
|
||||
is_best,
|
||||
filename=args.checkpoint_file + ".tar",
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@@ -1,457 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
||||
|
||||
"""
|
||||
Runs CIFAR10 training with differential privacy.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import shutil
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
import torch.utils.data
|
||||
import torch.utils.data.distributed
|
||||
import torch.utils.tensorboard as tensorboard
|
||||
import torchvision.models as models
|
||||
import torchvision.transforms as transforms
|
||||
from opacus import PrivacyEngine
|
||||
from opacus.utils import stats
|
||||
from opacus.utils.module_modification import convert_batchnorm_modules
|
||||
from torchvision.datasets import CIFAR10
|
||||
from tqdm import tqdm
|
||||
|
||||
from torch import vmap
|
||||
from make_functional import make_functional, load_weights
|
||||
from functional_utils import grad, grad_with_value
|
||||
from functools import partial
|
||||
# from resnet import resnet18
|
||||
|
||||
def save_checkpoint(state, is_best, filename="checkpoint.tar"):
|
||||
torch.save(state, filename)
|
||||
if is_best:
|
||||
shutil.copyfile(filename, "model_best.pth.tar")
|
||||
|
||||
|
||||
def accuracy(preds, labels):
|
||||
return (preds == labels).mean()
|
||||
|
||||
|
||||
def compute_norms(sample_grads):
|
||||
batch_size = sample_grads[0].shape[0]
|
||||
norms = [sample_grad.view(batch_size, -1).norm(2, dim=-1) for sample_grad in sample_grads]
|
||||
norms = torch.stack(norms, dim=0).norm(2, dim=0)
|
||||
return norms
|
||||
|
||||
def clip_and_accumulate_and_add_noise(sample_grads, max_per_sample_grad_norm=1.0, noise_multiplier=1.0):
|
||||
# step 0: compute the norms
|
||||
sample_norms = compute_norms(sample_grads)
|
||||
|
||||
# step 1: compute clipping factors
|
||||
clip_factor = max_per_sample_grad_norm / (sample_norms + 1e-6)
|
||||
clip_factor = clip_factor.clamp(max=1.0)
|
||||
|
||||
# step 2: clip
|
||||
grads = tuple(torch.einsum('i,i...', clip_factor, sample_grad)
|
||||
for sample_grad in sample_grads)
|
||||
|
||||
# step 3: add gaussian noise
|
||||
stddev = max_per_sample_grad_norm * noise_multiplier
|
||||
noises = tuple(torch.normal(0, stddev, grad_param.shape, device=grad_param.device)
|
||||
for grad_param in grads)
|
||||
grads = tuple(noise + grad_param for noise, grad_param in zip(noises, grads))
|
||||
|
||||
return grads
|
||||
|
||||
def train(args, model, train_loader, optimizer, epoch, device):
|
||||
use_prototype = False
|
||||
model.train()
|
||||
criterion = nn.CrossEntropyLoss()
|
||||
|
||||
losses = []
|
||||
top1_acc = []
|
||||
|
||||
for i, (images, target) in enumerate(tqdm(train_loader)):
|
||||
|
||||
images = images.to(device)
|
||||
target = target.to(device)
|
||||
|
||||
# Step 1: compute per-sample-grads
|
||||
weights, func_model, descriptors = make_functional(model)
|
||||
|
||||
def compute_loss_and_output(weights, image, target):
|
||||
images = image.unsqueeze(0)
|
||||
targets = target.unsqueeze(0)
|
||||
output = func_model(weights, (images,))
|
||||
loss = criterion(output, targets)
|
||||
return loss, output.squeeze(0)
|
||||
|
||||
# grad_with_value(f) returns a function that returns (1) the grad and
|
||||
# (2) the output. `has_aux=True` means that `f` returns a tuple of two values,
|
||||
# where the first is to be differentiated and the second is not to be
|
||||
# differentiated and further adds a 3rd output.
|
||||
#
|
||||
# We need to use `grad_with_value(..., has_aux=True)` because we do
|
||||
# some analyses on the returned loss and output.
|
||||
grads_loss_output = grad_with_value(compute_loss_and_output, has_aux=True)
|
||||
sample_grads, sample_loss, output = vmap(partial(grads_loss_output, weights))(images, target)
|
||||
loss = sample_loss.mean()
|
||||
|
||||
# Step 2: Clip the per-sample-grads, sum them to form grads, and add noise
|
||||
grads = clip_and_accumulate_and_add_noise(
|
||||
sample_grads, args.max_per_sample_grad_norm, args.sigma)
|
||||
|
||||
# `load_weights` is the inverse operation of make_functional. We put
|
||||
# things back into a model so that we can directly apply optimizers.
|
||||
# TODO(rzou): this might not be necessary, optimizers just take
|
||||
# the params straight up.
|
||||
load_weights(model, descriptors, weights)
|
||||
|
||||
for weight_grad, weight in zip(grads, model.parameters()):
|
||||
weight.grad = weight_grad.detach()
|
||||
|
||||
preds = np.argmax(output.detach().cpu().numpy(), axis=1)
|
||||
labels = target.detach().cpu().numpy()
|
||||
losses.append(loss.item())
|
||||
|
||||
# measure accuracy and record loss
|
||||
acc1 = accuracy(preds, labels)
|
||||
|
||||
top1_acc.append(acc1)
|
||||
stats.update(stats.StatType.TRAIN, acc1=acc1)
|
||||
|
||||
# make sure we take a step after processing the last mini-batch in the
|
||||
# epoch to ensure we start the next epoch with a clean state
|
||||
if ((i + 1) % args.n_accumulation_steps == 0) or ((i + 1) == len(train_loader)):
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
else:
|
||||
optimizer.virtual_step()
|
||||
|
||||
if i % args.print_freq == 0:
|
||||
print(
|
||||
f"\tTrain Epoch: {epoch} \t"
|
||||
f"Loss: {np.mean(losses):.6f} "
|
||||
f"Acc@1: {np.mean(top1_acc):.6f} "
|
||||
)
|
||||
|
||||
def test(args, model, test_loader, device):
|
||||
model.eval()
|
||||
criterion = nn.CrossEntropyLoss()
|
||||
losses = []
|
||||
top1_acc = []
|
||||
|
||||
with torch.no_grad():
|
||||
for images, target in tqdm(test_loader):
|
||||
images = images.to(device)
|
||||
target = target.to(device)
|
||||
|
||||
output = model(images)
|
||||
loss = criterion(output, target)
|
||||
preds = np.argmax(output.detach().cpu().numpy(), axis=1)
|
||||
labels = target.detach().cpu().numpy()
|
||||
acc1 = accuracy(preds, labels)
|
||||
|
||||
losses.append(loss.item())
|
||||
top1_acc.append(acc1)
|
||||
|
||||
top1_avg = np.mean(top1_acc)
|
||||
stats.update(stats.StatType.TEST, acc1=top1_avg)
|
||||
|
||||
print(f"\tTest set:" f"Loss: {np.mean(losses):.6f} " f"Acc@1: {top1_avg :.6f} ")
|
||||
return np.mean(top1_acc)
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="PyTorch CIFAR10 DP Training")
|
||||
parser.add_argument(
|
||||
"-j",
|
||||
"--workers",
|
||||
default=2,
|
||||
type=int,
|
||||
metavar="N",
|
||||
help="number of data loading workers (default: 2)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--epochs",
|
||||
default=90,
|
||||
type=int,
|
||||
metavar="N",
|
||||
help="number of total epochs to run",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--start-epoch",
|
||||
default=1,
|
||||
type=int,
|
||||
metavar="N",
|
||||
help="manual epoch number (useful on restarts)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-b",
|
||||
"--batch-size",
|
||||
# This should be 256, but that OOMs using the prototype.
|
||||
default=64,
|
||||
type=int,
|
||||
metavar="N",
|
||||
help="mini-batch size (default: 64), this is the total "
|
||||
"batch size of all GPUs on the current node when "
|
||||
"using Data Parallel or Distributed Data Parallel",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-na",
|
||||
"--n_accumulation_steps",
|
||||
default=1,
|
||||
type=int,
|
||||
metavar="N",
|
||||
help="number of mini-batches to accumulate into an effective batch",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--lr",
|
||||
"--learning-rate",
|
||||
default=0.001,
|
||||
type=float,
|
||||
metavar="LR",
|
||||
help="initial learning rate",
|
||||
dest="lr",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--momentum", default=0.9, type=float, metavar="M", help="SGD momentum"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--wd",
|
||||
"--weight-decay",
|
||||
default=5e-4,
|
||||
type=float,
|
||||
metavar="W",
|
||||
help="SGD weight decay (default: 1e-4)",
|
||||
dest="weight_decay",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-p",
|
||||
"--print-freq",
|
||||
default=10,
|
||||
type=int,
|
||||
metavar="N",
|
||||
help="print frequency (default: 10)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--resume",
|
||||
default="",
|
||||
type=str,
|
||||
metavar="PATH",
|
||||
help="path to latest checkpoint (default: none)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-e",
|
||||
"--evaluate",
|
||||
dest="evaluate",
|
||||
action="store_true",
|
||||
help="evaluate model on validation set",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--seed", default=None, type=int, help="seed for initializing training. "
|
||||
)
|
||||
parser.add_argument(
|
||||
"--device",
|
||||
type=str,
|
||||
default="cuda",
|
||||
help="GPU ID for this process (default: 'cuda')",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--sigma",
|
||||
type=float,
|
||||
default=1.0,
|
||||
metavar="S",
|
||||
help="Noise multiplier (default 1.0)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-c",
|
||||
"--max-per-sample-grad_norm",
|
||||
type=float,
|
||||
default=1.0,
|
||||
metavar="C",
|
||||
help="Clip per-sample gradients to this norm (default 1.0)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--disable-dp",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Disable privacy training and just train with vanilla SGD",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--secure-rng",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Enable Secure RNG to have trustworthy privacy guarantees. Comes at a performance cost",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--delta",
|
||||
type=float,
|
||||
default=1e-5,
|
||||
metavar="D",
|
||||
help="Target delta (default: 1e-5)",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--checkpoint-file",
|
||||
type=str,
|
||||
default="checkpoint",
|
||||
help="path to save check points",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--data-root",
|
||||
type=str,
|
||||
default="../cifar10",
|
||||
help="Where CIFAR10 is/will be stored",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--log-dir", type=str, default="", help="Where Tensorboard log will be stored"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--optim",
|
||||
type=str,
|
||||
default="Adam",
|
||||
help="Optimizer to use (Adam, RMSprop, SGD)",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
args.disable_dp = True
|
||||
|
||||
if args.disable_dp and args.n_accumulation_steps > 1:
|
||||
raise ValueError("Virtual steps only works with enabled DP")
|
||||
|
||||
# The following few lines, enable stats gathering about the run
|
||||
# 1. where the stats should be logged
|
||||
stats.set_global_summary_writer(
|
||||
tensorboard.SummaryWriter(os.path.join("/tmp/stat", args.log_dir))
|
||||
)
|
||||
# 2. enable stats
|
||||
stats.add(
|
||||
# stats about gradient norms aggregated for all layers
|
||||
stats.Stat(stats.StatType.GRAD, "AllLayers", frequency=0.1),
|
||||
# stats about gradient norms per layer
|
||||
stats.Stat(stats.StatType.GRAD, "PerLayer", frequency=0.1),
|
||||
# stats about clipping
|
||||
stats.Stat(stats.StatType.GRAD, "ClippingStats", frequency=0.1),
|
||||
# stats on training accuracy
|
||||
stats.Stat(stats.StatType.TRAIN, "accuracy", frequency=0.01),
|
||||
# stats on validation accuracy
|
||||
stats.Stat(stats.StatType.TEST, "accuracy"),
|
||||
)
|
||||
|
||||
# The following lines enable stat gathering for the clipping process
|
||||
# and set a default of per layer clipping for the Privacy Engine
|
||||
clipping = {"clip_per_layer": False, "enable_stat": True}
|
||||
|
||||
if args.secure_rng:
|
||||
assert False
|
||||
try:
|
||||
import torchcsprng as prng
|
||||
except ImportError as e:
|
||||
msg = (
|
||||
"To use secure RNG, you must install the torchcsprng package! "
|
||||
"Check out the instructions here: https://github.com/pytorch/csprng#installation"
|
||||
)
|
||||
raise ImportError(msg) from e
|
||||
|
||||
generator = prng.create_random_device_generator("/dev/urandom")
|
||||
|
||||
else:
|
||||
generator = None
|
||||
|
||||
augmentations = [
|
||||
transforms.RandomCrop(32, padding=4),
|
||||
transforms.RandomHorizontalFlip(),
|
||||
]
|
||||
normalize = [
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
|
||||
]
|
||||
train_transform = transforms.Compose(
|
||||
augmentations + normalize if args.disable_dp else normalize
|
||||
)
|
||||
|
||||
test_transform = transforms.Compose(normalize)
|
||||
|
||||
train_dataset = CIFAR10(
|
||||
root=args.data_root, train=True, download=True, transform=train_transform
|
||||
)
|
||||
|
||||
train_loader = torch.utils.data.DataLoader(
|
||||
train_dataset,
|
||||
batch_size=args.batch_size,
|
||||
shuffle=True,
|
||||
num_workers=args.workers,
|
||||
drop_last=True,
|
||||
generator=generator,
|
||||
)
|
||||
|
||||
test_dataset = CIFAR10(
|
||||
root=args.data_root, train=False, download=True, transform=test_transform
|
||||
)
|
||||
test_loader = torch.utils.data.DataLoader(
|
||||
test_dataset,
|
||||
batch_size=args.batch_size,
|
||||
shuffle=False,
|
||||
num_workers=args.workers,
|
||||
)
|
||||
|
||||
best_acc1 = 0
|
||||
device = torch.device(args.device)
|
||||
model = convert_batchnorm_modules(models.resnet18(num_classes=10))
|
||||
# model = CIFAR10Model()
|
||||
model = model.to(device)
|
||||
|
||||
if args.optim == "SGD":
|
||||
optimizer = optim.SGD(
|
||||
model.parameters(),
|
||||
lr=args.lr,
|
||||
momentum=args.momentum,
|
||||
weight_decay=args.weight_decay,
|
||||
)
|
||||
elif args.optim == "RMSprop":
|
||||
optimizer = optim.RMSprop(model.parameters(), lr=args.lr)
|
||||
elif args.optim == "Adam":
|
||||
optimizer = optim.Adam(model.parameters(), lr=args.lr)
|
||||
else:
|
||||
raise NotImplementedError("Optimizer not recognized. Please check spelling")
|
||||
|
||||
if not args.disable_dp:
|
||||
privacy_engine = PrivacyEngine(
|
||||
model,
|
||||
batch_size=args.batch_size * args.n_accumulation_steps,
|
||||
sample_size=len(train_dataset),
|
||||
alphas=[1 + x / 10.0 for x in range(1, 100)] + list(range(12, 64)),
|
||||
noise_multiplier=args.sigma,
|
||||
max_grad_norm=args.max_per_sample_grad_norm,
|
||||
secure_rng=args.secure_rng,
|
||||
**clipping,
|
||||
)
|
||||
privacy_engine.attach(optimizer)
|
||||
|
||||
for epoch in range(args.start_epoch, args.epochs + 1):
|
||||
train(args, model, train_loader, optimizer, epoch, device)
|
||||
top1_acc = test(args, model, test_loader, device)
|
||||
|
||||
# remember best acc@1 and save checkpoint
|
||||
is_best = top1_acc > best_acc1
|
||||
best_acc1 = max(top1_acc, best_acc1)
|
||||
|
||||
save_checkpoint(
|
||||
{
|
||||
"epoch": epoch + 1,
|
||||
"arch": "ResNet18",
|
||||
"state_dict": model.state_dict(),
|
||||
"best_acc1": best_acc1,
|
||||
"optimizer": optimizer.state_dict(),
|
||||
},
|
||||
is_best,
|
||||
filename=args.checkpoint_file + ".tar",
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@@ -1,456 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
||||
|
||||
"""
|
||||
Runs CIFAR10 training with differential privacy.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import shutil
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
import torch.utils.data
|
||||
import torch.utils.data.distributed
|
||||
import torch.utils.tensorboard as tensorboard
|
||||
import torchvision.models as models
|
||||
import torchvision.transforms as transforms
|
||||
from opacus import PrivacyEngine
|
||||
from opacus.utils import stats
|
||||
from opacus.utils.module_modification import convert_batchnorm_modules
|
||||
from torchvision.datasets import CIFAR10
|
||||
from tqdm import tqdm
|
||||
|
||||
from torch import vmap
|
||||
from make_functional import make_functional, load_weights
|
||||
from functional_utils import grad, grad_with_value
|
||||
from functools import partial
|
||||
# from resnet import resnet18
|
||||
|
||||
def save_checkpoint(state, is_best, filename="checkpoint.tar"):
|
||||
torch.save(state, filename)
|
||||
if is_best:
|
||||
shutil.copyfile(filename, "model_best.pth.tar")
|
||||
|
||||
|
||||
def accuracy(preds, labels):
|
||||
return (preds == labels).mean()
|
||||
|
||||
|
||||
def compute_norms(sample_grads):
|
||||
batch_size = sample_grads[0].shape[0]
|
||||
norms = [sample_grad.view(batch_size, -1).norm(2, dim=-1) for sample_grad in sample_grads]
|
||||
norms = torch.stack(norms, dim=0).norm(2, dim=0)
|
||||
return norms
|
||||
|
||||
def clip_and_accumulate_and_add_noise(sample_grads, max_per_sample_grad_norm=1.0, noise_multiplier=1.0):
|
||||
# step 0: compute the norms
|
||||
sample_norms = compute_norms(sample_grads)
|
||||
|
||||
# step 1: compute clipping factors
|
||||
clip_factor = max_per_sample_grad_norm / (sample_norms + 1e-6)
|
||||
clip_factor = clip_factor.clamp(max=1.0)
|
||||
|
||||
# step 2: clip
|
||||
grads = tuple(torch.einsum('i,i...', clip_factor, sample_grad)
|
||||
for sample_grad in sample_grads)
|
||||
|
||||
# step 3: add gaussian noise
|
||||
stddev = max_per_sample_grad_norm * noise_multiplier
|
||||
noises = tuple(torch.normal(0, stddev, grad_param.shape, device=grad_param.device)
|
||||
for grad_param in grads)
|
||||
grads = tuple(noise + grad_param for noise, grad_param in zip(noises, grads))
|
||||
|
||||
return grads
|
||||
|
||||
def train(args, model, train_loader, optimizer, epoch, device):
|
||||
model.train()
|
||||
criterion = nn.CrossEntropyLoss()
|
||||
|
||||
losses = []
|
||||
top1_acc = []
|
||||
|
||||
for i, (images, target) in enumerate(tqdm(train_loader)):
|
||||
|
||||
images = images.to(device)
|
||||
target = target.to(device)
|
||||
|
||||
# Step 1: compute per-sample-grads
|
||||
weights, func_model, descriptors = make_functional(model)
|
||||
|
||||
def compute_loss_and_output(weights, image, target):
|
||||
images = image.unsqueeze(0)
|
||||
targets = target.unsqueeze(0)
|
||||
output = func_model(weights, (images,))
|
||||
loss = criterion(output, targets)
|
||||
return loss, output.squeeze(0)
|
||||
|
||||
# grad_with_value(f) returns a function that returns (1) the grad and
|
||||
# (2) the output. `has_aux=True` means that `f` returns a tuple of two values,
|
||||
# where the first is to be differentiated and the second is not to be
|
||||
# differentiated and further adds a 3rd output.
|
||||
#
|
||||
# We need to use `grad_with_value(..., has_aux=True)` because we do
|
||||
# some analyses on the returned loss and output.
|
||||
grads_loss_output = grad_with_value(compute_loss_and_output, has_aux=True)
|
||||
sample_grads, sample_loss, output = vmap(partial(grads_loss_output, weights))(images, target)
|
||||
loss = sample_loss.mean()
|
||||
|
||||
# Step 2: Clip the per-sample-grads, sum them to form grads, and add noise
|
||||
grads = clip_and_accumulate_and_add_noise(
|
||||
sample_grads, args.max_per_sample_grad_norm, args.sigma)
|
||||
|
||||
# `load_weights` is the inverse operation of make_functional. We put
|
||||
# things back into a model so that we can directly apply optimizers.
|
||||
# TODO(rzou): this might not be necessary, optimizers just take
|
||||
# the params straight up.
|
||||
load_weights(model, descriptors, weights)
|
||||
|
||||
for weight_grad, weight in zip(grads, model.parameters()):
|
||||
weight.grad = weight_grad.detach()
|
||||
|
||||
preds = np.argmax(output.detach().cpu().numpy(), axis=1)
|
||||
labels = target.detach().cpu().numpy()
|
||||
losses.append(loss.item())
|
||||
|
||||
# measure accuracy and record loss
|
||||
acc1 = accuracy(preds, labels)
|
||||
|
||||
top1_acc.append(acc1)
|
||||
stats.update(stats.StatType.TRAIN, acc1=acc1)
|
||||
|
||||
# make sure we take a step after processing the last mini-batch in the
|
||||
# epoch to ensure we start the next epoch with a clean state
|
||||
if ((i + 1) % args.n_accumulation_steps == 0) or ((i + 1) == len(train_loader)):
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
else:
|
||||
optimizer.virtual_step()
|
||||
|
||||
if i % args.print_freq == 0:
|
||||
print(
|
||||
f"\tTrain Epoch: {epoch} \t"
|
||||
f"Loss: {np.mean(losses):.6f} "
|
||||
f"Acc@1: {np.mean(top1_acc):.6f} "
|
||||
)
|
||||
|
||||
def test(args, model, test_loader, device):
|
||||
model.eval()
|
||||
criterion = nn.CrossEntropyLoss()
|
||||
losses = []
|
||||
top1_acc = []
|
||||
|
||||
with torch.no_grad():
|
||||
for images, target in tqdm(test_loader):
|
||||
images = images.to(device)
|
||||
target = target.to(device)
|
||||
|
||||
output = model(images)
|
||||
loss = criterion(output, target)
|
||||
preds = np.argmax(output.detach().cpu().numpy(), axis=1)
|
||||
labels = target.detach().cpu().numpy()
|
||||
acc1 = accuracy(preds, labels)
|
||||
|
||||
losses.append(loss.item())
|
||||
top1_acc.append(acc1)
|
||||
|
||||
top1_avg = np.mean(top1_acc)
|
||||
stats.update(stats.StatType.TEST, acc1=top1_avg)
|
||||
|
||||
print(f"\tTest set:" f"Loss: {np.mean(losses):.6f} " f"Acc@1: {top1_avg :.6f} ")
|
||||
return np.mean(top1_acc)
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="PyTorch CIFAR10 DP Training")
|
||||
parser.add_argument(
|
||||
"-j",
|
||||
"--workers",
|
||||
default=2,
|
||||
type=int,
|
||||
metavar="N",
|
||||
help="number of data loading workers (default: 2)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--epochs",
|
||||
default=90,
|
||||
type=int,
|
||||
metavar="N",
|
||||
help="number of total epochs to run",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--start-epoch",
|
||||
default=1,
|
||||
type=int,
|
||||
metavar="N",
|
||||
help="manual epoch number (useful on restarts)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-b",
|
||||
"--batch-size",
|
||||
# This should be 256, but that OOMs using the prototype.
|
||||
default=64,
|
||||
type=int,
|
||||
metavar="N",
|
||||
help="mini-batch size (default: 64), this is the total "
|
||||
"batch size of all GPUs on the current node when "
|
||||
"using Data Parallel or Distributed Data Parallel",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-na",
|
||||
"--n_accumulation_steps",
|
||||
default=1,
|
||||
type=int,
|
||||
metavar="N",
|
||||
help="number of mini-batches to accumulate into an effective batch",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--lr",
|
||||
"--learning-rate",
|
||||
default=0.001,
|
||||
type=float,
|
||||
metavar="LR",
|
||||
help="initial learning rate",
|
||||
dest="lr",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--momentum", default=0.9, type=float, metavar="M", help="SGD momentum"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--wd",
|
||||
"--weight-decay",
|
||||
default=5e-4,
|
||||
type=float,
|
||||
metavar="W",
|
||||
help="SGD weight decay (default: 1e-4)",
|
||||
dest="weight_decay",
|
||||
)
|
||||
parser.add_argument(
    "-p",
    "--print-freq",
    default=10,
    type=int,
    metavar="N",
    help="print frequency (default: 10)",
)
parser.add_argument(
    "--resume",
    default="",
    type=str,
    metavar="PATH",
    help="path to latest checkpoint (default: none)",
)
parser.add_argument(
    "-e",
    "--evaluate",
    dest="evaluate",
    action="store_true",
    help="evaluate model on validation set",
)
parser.add_argument(
    "--seed", default=None, type=int, help="seed for initializing training. "
)
parser.add_argument(
    "--device",
    type=str,
    default="cuda",
    help="GPU ID for this process (default: 'cuda')",
)
parser.add_argument(
    "--sigma",
    type=float,
    default=1.0,
    metavar="S",
    help="Noise multiplier (default 1.0)",
)
parser.add_argument(
    "-c",
    "--max-per-sample-grad_norm",
    type=float,
    default=1.0,
    metavar="C",
    help="Clip per-sample gradients to this norm (default 1.0)",
)
parser.add_argument(
    "--disable-dp",
    action="store_true",
    default=False,
    help="Disable privacy training and just train with vanilla SGD",
)
parser.add_argument(
    "--secure-rng",
    action="store_true",
    default=False,
    help="Enable Secure RNG to have trustworthy privacy guarantees. Comes at a performance cost",
)
parser.add_argument(
    "--delta",
    type=float,
    default=1e-5,
    metavar="D",
    help="Target delta (default: 1e-5)",
)

parser.add_argument(
    "--checkpoint-file",
    type=str,
    default="checkpoint",
    help="path to save checkpoints",
)
parser.add_argument(
    "--data-root",
    type=str,
    default="../cifar10",
    help="Where CIFAR10 is/will be stored",
)
parser.add_argument(
    "--log-dir", type=str, default="", help="Where Tensorboard log will be stored"
)
parser.add_argument(
    "--optim",
    type=str,
    default="Adam",
    help="Optimizer to use (Adam, RMSprop, SGD)",
)

args = parser.parse_args()
args.disable_dp = True

if args.disable_dp and args.n_accumulation_steps > 1:
    raise ValueError("Virtual steps only work with DP enabled")
# The following few lines enable stats gathering about the run
# 1. where the stats should be logged
stats.set_global_summary_writer(
    tensorboard.SummaryWriter(os.path.join("/tmp/stat", args.log_dir))
)
# 2. enable stats
stats.add(
    # stats about gradient norms aggregated for all layers
    stats.Stat(stats.StatType.GRAD, "AllLayers", frequency=0.1),
    # stats about gradient norms per layer
    stats.Stat(stats.StatType.GRAD, "PerLayer", frequency=0.1),
    # stats about clipping
    stats.Stat(stats.StatType.GRAD, "ClippingStats", frequency=0.1),
    # stats on training accuracy
    stats.Stat(stats.StatType.TRAIN, "accuracy", frequency=0.01),
    # stats on validation accuracy
    stats.Stat(stats.StatType.TEST, "accuracy"),
)

# The following lines enable stat gathering for the clipping process
# and set flat (not per-layer) clipping as the default for the Privacy Engine
clipping = {"clip_per_layer": False, "enable_stat": True}
if args.secure_rng:
    assert False
    try:
        import torchcsprng as prng
    except ImportError as e:
        msg = (
            "To use secure RNG, you must install the torchcsprng package! "
            "Check out the instructions here: https://github.com/pytorch/csprng#installation"
        )
        raise ImportError(msg) from e

    generator = prng.create_random_device_generator("/dev/urandom")

else:
    generator = None
augmentations = [
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
]
normalize = [
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
]
train_transform = transforms.Compose(
    augmentations + normalize if args.disable_dp else normalize
)

test_transform = transforms.Compose(normalize)

train_dataset = CIFAR10(
    root=args.data_root, train=True, download=True, transform=train_transform
)

train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=args.batch_size,
    shuffle=True,
    num_workers=args.workers,
    drop_last=True,
    generator=generator,
)

test_dataset = CIFAR10(
    root=args.data_root, train=False, download=True, transform=test_transform
)
test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=args.batch_size,
    shuffle=False,
    num_workers=args.workers,
)

best_acc1 = 0
device = torch.device(args.device)
model = convert_batchnorm_modules(models.resnet18(num_classes=10))
# model = CIFAR10Model()
model = model.to(device)
if args.optim == "SGD":
    optimizer = optim.SGD(
        model.parameters(),
        lr=args.lr,
        momentum=args.momentum,
        weight_decay=args.weight_decay,
    )
elif args.optim == "RMSprop":
    optimizer = optim.RMSprop(model.parameters(), lr=args.lr)
elif args.optim == "Adam":
    optimizer = optim.Adam(model.parameters(), lr=args.lr)
else:
    raise NotImplementedError("Optimizer not recognized. Please check spelling")

if not args.disable_dp:
    privacy_engine = PrivacyEngine(
        model,
        batch_size=args.batch_size * args.n_accumulation_steps,
        sample_size=len(train_dataset),
        alphas=[1 + x / 10.0 for x in range(1, 100)] + list(range(12, 64)),
        noise_multiplier=args.sigma,
        max_grad_norm=args.max_per_sample_grad_norm,
        secure_rng=args.secure_rng,
        **clipping,
    )
    privacy_engine.attach(optimizer)
for epoch in range(args.start_epoch, args.epochs + 1):
    train(args, model, train_loader, optimizer, epoch, device)
    top1_acc = test(args, model, test_loader, device)

    # remember best acc@1 and save checkpoint
    is_best = top1_acc > best_acc1
    best_acc1 = max(top1_acc, best_acc1)

    save_checkpoint(
        {
            "epoch": epoch + 1,
            "arch": "ResNet18",
            "state_dict": model.state_dict(),
            "best_acc1": best_acc1,
            "optimizer": optimizer.state_dict(),
        },
        is_best,
        filename=args.checkpoint_file + ".tar",
    )


if __name__ == "__main__":
    main()
@ -1,71 +0,0 @@
import torch
import torch.nn as nn
from torch import Tensor
from typing import List, Tuple
import copy


# Utilities to make nn.Module "functional".
# In particular, the goal is to be able to provide a function that takes as input
# the parameters and evaluates the nn.Module using fixed inputs.
def _del_nested_attr(obj: nn.Module, names: List[str]) -> None:
    """
    Deletes the attribute specified by the given list of names.
    For example, to delete the attribute obj.conv.weight,
    use _del_nested_attr(obj, ['conv', 'weight'])
    """
    if len(names) == 1:
        delattr(obj, names[0])
    else:
        _del_nested_attr(getattr(obj, names[0]), names[1:])

def _set_nested_attr(obj: nn.Module, names: List[str], value: Tensor) -> None:
    """
    Set the attribute specified by the given list of names to value.
    For example, to set the attribute obj.conv.weight,
    use _set_nested_attr(obj, ['conv', 'weight'], value)
    """
    if len(names) == 1:
        setattr(obj, names[0], value)
    else:
        _set_nested_attr(getattr(obj, names[0]), names[1:], value)

def extract_weights(mod: nn.Module) -> Tuple[Tuple[Tensor, ...], List[str]]:
    """
    This function removes all the Parameters from the model and
    returns them as a tuple, as well as their original attribute names.
    The weights must be re-loaded with `load_weights` before the model
    can be used again.
    Note that this function modifies the model in place and after this
    call, mod.parameters() will be empty.
    """
    orig_params = tuple(mod.parameters())
    # Remove all the parameters in the model
    names = []
    for name, p in list(mod.named_parameters()):
        _del_nested_attr(mod, name.split("."))
        names.append(name)

    # Make params regular Tensors instead of nn.Parameter
    params = tuple(p.detach().requires_grad_() for p in orig_params)
    return params, names

def load_weights(mod: nn.Module, names: List[str], params: Tuple[Tensor, ...], as_params=False) -> None:
    """
    Reload a set of weights so that `mod` can be used again to perform a forward pass.
    Note that the `params` are regular Tensors (that can have history) and so are left
    as Tensors. This means that mod.parameters() will still be empty after this call.
    """
    for name, p in zip(names, params):
        if as_params:
            p = nn.Parameter(p)
        _set_nested_attr(mod, name.split("."), p)


def make_functional(model: nn.Module):
    weights, descriptors = extract_weights(model)

    def fun(weights, data):
        mutable_model = copy.deepcopy(model)
        load_weights(mutable_model, descriptors, weights)
        return mutable_model(*data)

    return weights, fun, descriptors
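For orientation, here is a minimal usage sketch of the helpers above (not part of the original file); `MyModel` and `inputs` are hypothetical placeholders, and the call pattern follows the functions as written in this listing rather than the final functorch API:

    # Build a "functional" view of the module: the weights become explicit inputs.
    model = MyModel()
    weights, func, descriptors = make_functional(model)

    # Forward pass driven entirely by the explicit weights tuple;
    # `data` is a tuple of positional inputs to the module.
    output = func(weights, (inputs,))

    # Re-register the weights as Parameters so the original module works normally again.
    load_weights(model, descriptors, weights, as_params=True)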
@ -2,7 +2,7 @@ import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from functorch import make_functional, grad_with_value, vmap
from functorch import make_functional, grad_and_value, vmap

# Adapted from http://willwhitney.com/parallel-training-jax.html
# GOAL: Demonstrate that it is possible to use eager-mode vmap
@ -58,7 +58,7 @@ def train_step_fn(weights, batch, targets, lr=0.2):
loss = loss_fn(output, targets)
return loss

grad_weights, loss = grad_with_value(compute_loss)(weights, batch, targets)
grad_weights, loss = grad_and_value(compute_loss)(weights, batch, targets)

# NB: PyTorch is missing a "functional optimizer API" (possibly coming soon)
# so we are going to re-implement SGD here.
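For context on the "functional optimizer API" note above, the SGD re-implementation in this example amounts to rebuilding the weight tuple by hand; a rough sketch, assuming `weights`, `grad_weights`, and `lr` as they appear in the surrounding hunk:

    # One plain SGD step: walk the per-tensor gradients and update each weight.
    new_weights = tuple(w - lr * g for w, g in zip(weights, grad_weights))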
@ -141,7 +141,7 @@ def grad_and_value(f, argnums=0, has_aux=False):
aux = _undo_create_differentiable(aux, level)
_grad_decrement_nesting()
if has_aux:
return grad_input, output, aux
return grad_input, (output, aux)
return grad_input, output
return wrapper
@ -150,7 +150,9 @@ def grad(f, argnums=0, has_aux=False):
def wrapper(*args):
results = grad_and_value(f, argnums, has_aux=has_aux)(*args)
if has_aux:
return results[0], results[2]
return results[0]
grad, (value, aux) = results
return grad, aux
grad, value = results
return grad
return wrapper
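The two hunks above change the return structure of these transforms: with `has_aux=True`, `grad_and_value(f)` now returns `(grads, (value, aux))` rather than a flat triple, and `grad(f)` returns `(grads, aux)`. A rough sketch of the resulting call pattern, where `model_fn` and `compute_loss` are hypothetical placeholders:

    def loss_fn(weights, batch, targets):
        output = model_fn(weights, batch)
        # The first element is differentiated; the second is the auxiliary output.
        return compute_loss(output, targets), output

    grads, (loss, output) = grad_and_value(loss_fn, has_aux=True)(weights, batch, targets)
    grads, output = grad(loss_fn, has_aux=True)(weights, batch, targets)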