mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[functorch] Run all correctness tests
@@ -6,17 +6,31 @@ import unittest
 import functools
 import itertools
 import warnings
+import math
+from typing import Callable, Type
 from torch.testing._internal.common_device_type import instantiate_device_type_tests, \
     skipCUDAIfNoMagma, onlyOnCPUAndCUDA
 import types
+from functools import partial
 
 import functorch
-from functorch import grad, vjp, vmap, make_functional, jacrev, make_functional_with_buffers
+from functorch import (
+    grad, vjp, vmap, jacrev, grad_with_value,
+    make_functional, make_functional_with_buffers,
+)
 
 # NB: numpy is a testing dependency!
 import numpy as np
 
+USE_TORCHVISION = False
+try:
+    import torchvision
+    USE_TORCHVISION = True
+except ImportError:
+    warnings.warn("Couldn't import torchvision. Some of our tests use it, try "
+                  "to install it with commands from pytorch.org, post-fixed with "
+                  "`--no-deps` to avoid overwriting the pytorch installation")
 
 
 class TestGradTransform(TestCase):
     def test_primitive(self, device):
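The notable new import above is grad_with_value. Judging from its use in the tests below, grad_with_value(f)(*args) returns the pair (gradients, f(*args)), where grad(f)(*args) returns the gradients alone; a minimal sketch under that assumption:

import torch
from functorch import grad, grad_with_value

x = torch.randn([])
gx = grad(torch.sin)(x)                # gradient only: cos(x)
gx, y = grad_with_value(torch.sin)(x)  # gradient and value: (cos(x), sin(x))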
@@ -615,6 +629,152 @@ class TestExamplesCorrectness(TestCase):
 
         self.assertEqual(result_grads, expected_grads)
 
+    def test_ensemble_regression(self, device):
+        def make_spirals(n_samples, noise_std=0., rotations=1.):
+            ts = torch.linspace(0, 1, n_samples)
+            rs = ts ** 0.5
+            thetas = rs * rotations * 2 * math.pi
+            signs = torch.randint(0, 2, (n_samples,)) * 2 - 1
+            labels = (signs > 0).to(torch.long)
+
+            xs = rs * signs * torch.cos(thetas) + torch.randn(n_samples) * noise_std
+            ys = rs * signs * torch.sin(thetas) + torch.randn(n_samples) * noise_std
+            points = torch.stack([xs, ys], dim=1)
+            return points.to(device), labels.to(device)
+
+        points, labels = make_spirals(100, noise_std=0.05)
+
+        class MLPClassifier(nn.Module):
+            def __init__(self, hidden_dim=32, n_classes=2):
+                super().__init__()
+                self.hidden_dim = hidden_dim
+                self.n_classes = n_classes
+
+                self.fc1 = nn.Linear(2, self.hidden_dim)
+                self.fc2 = nn.Linear(self.hidden_dim, self.n_classes)
+
+            def forward(self, x):
+                x = self.fc1(x)
+                x = F.relu(x)
+                x = self.fc2(x)
+                x = F.log_softmax(x, -1)
+                return x
+
+        loss_fn = nn.NLLLoss()
+
+        weights, func_model, _ = make_functional(MLPClassifier().to(device))
+
+        def train_step_fn(use_transform, weights, batch, targets, lr=0.2):
+            def compute_loss(weights, batch, targets):
+                output = func_model(weights, (batch,))
+                loss = loss_fn(output, targets)
+                return loss
+
+            if use_transform:
+                grad_weights, loss = grad_with_value(compute_loss)(weights, batch, targets)
+            else:
+                loss = compute_loss(weights, batch, targets)
+                grad_weights = torch.autograd.grad(loss, weights)
+
+            new_weights = []
+            with torch.no_grad():
+                for grad_weight, weight in zip(grad_weights, weights):
+                    new_weights.append(weight - grad_weight * lr)
+            # NB: return looks weird because torch.vmap must return Tensors
+            return (loss, *new_weights)
+
+        def unpack(train_result):
+            return train_result[0], train_result[1:]
+
+        def init_fn(num_models):
+            models = tuple(MLPClassifier().to(device) for _ in range(num_models))
+            weights = tuple(make_functional(model)[0] for model in models)
+            weights = tuple(zip(*weights))
+            weights = tuple(torch.stack(shards).detach() for shards in weights)
+            return weights
+
+        def slice_weights(batched_weights, index):
+            return tuple(weight[index].detach().requires_grad_() for weight in batched_weights)
+
+        batched_weights = init_fn(num_models=2)
+        parallel_train_step_fn = vmap(partial(train_step_fn, True), in_dims=(0, None, None))
+
+        result_loss, result_weights = unpack(parallel_train_step_fn(batched_weights, points, labels))
+
+        loss0, weights0 = unpack(train_step_fn(False, slice_weights(batched_weights, 0), points, labels))
+        loss1, weights1 = unpack(train_step_fn(False, slice_weights(batched_weights, 1), points, labels))
+        expected_loss = torch.stack([loss0, loss1])
+        expected_weights = tuple(torch.stack([w0, w1]) for w0, w1 in zip(weights0, weights1))
+
+        self.assertEqual(result_loss, expected_loss)
+        self.assertEqual(result_weights, expected_weights)
+
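The pattern under test above: each model's parameters are stacked along a new leading "model" dimension, and a single train step is vmapped over that dimension with in_dims=(0, None, None), so every ensembled model sees the same batch. A stripped-down sketch of the same idea, assuming only the make_functional/func_model calling convention used in the diff:

import torch
import torch.nn as nn
from functorch import vmap, make_functional

weights_a, func_model, _ = make_functional(nn.Linear(2, 2))
weights_b, _, _ = make_functional(nn.Linear(2, 2))
# Stack corresponding parameters of the two models along a new dim 0.
batched = tuple(torch.stack(ws) for ws in zip(weights_a, weights_b))

x = torch.randn(5, 2)
# Map over the stacked weights; the same input is broadcast to each model.
out = vmap(lambda w: func_model(w, (x,)), in_dims=(0,))(batched)
# out has shape (2, 5, 2): one output per ensembled model.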
+    @unittest.skipIf(not USE_TORCHVISION, "test requires torchvision")
+    def test_resnet18_per_sample_grads(self, device):
+        # Straight out of opacus
+        def _replace_child(
+            root: nn.Module, child_name: str, converter: Callable[[nn.Module], nn.Module]
+        ) -> None:
+            # find the immediate parent
+            parent = root
+            nameList = child_name.split(".")
+            for name in nameList[:-1]:
+                parent = parent._modules[name]
+            # set to identity
+            parent._modules[nameList[-1]] = converter(parent._modules[nameList[-1]])
+
+        def replace_all_modules(
+            root: nn.Module,
+            target_class: Type[nn.Module],
+            converter: Callable[[nn.Module], nn.Module],
+        ) -> nn.Module:
+            # base case
+            if isinstance(root, target_class):
+                return converter(root)
+
+            for name, obj in root.named_modules():
+                if isinstance(obj, target_class):
+                    _replace_child(root, name, converter)
+            return root
+
+        def _batchnorm_to_groupnorm(module: nn.modules.batchnorm._BatchNorm) -> nn.Module:
+            return nn.GroupNorm(min(32, module.num_features), module.num_features, affine=True)
+
+        def convert_batchnorm_modules(
+            model: nn.Module,
+            converter: Callable[
+                [nn.modules.batchnorm._BatchNorm], nn.Module
+            ] = _batchnorm_to_groupnorm,
+        ) -> nn.Module:
+            return replace_all_modules(model, nn.modules.batchnorm._BatchNorm, converter)
+
+        import torchvision.models as models
+        model = convert_batchnorm_modules(models.resnet18(num_classes=10)).to(device)
+        criterion = nn.CrossEntropyLoss()
+
+        weights, func_model, descriptors = make_functional(model)
+
+        def compute_loss(weights, image, target):
+            images = image.unsqueeze(0)
+            targets = target.unsqueeze(0)
+            output = func_model(weights, (images,))
+            loss = criterion(output, targets)
+            return loss
+
+        batch_size = 3
+        images = torch.randn(batch_size, 3, 32, 32, device=device)
+        targets = torch.randint(0, 10, (batch_size,), device=device)
+
+        result_grads = vmap(grad(compute_loss), in_dims=(None, 0, 0))(weights, images, targets)
+
+        expected_grads = [
+            torch.autograd.grad(compute_loss(weights, images[i], targets[i]), weights)
+            for i in range(batch_size)
+        ]
+        expected_grads = [torch.stack(shards) for shards in zip(*expected_grads)]
+
+        self.assertEqual(result_grads, expected_grads)
+
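The composition worth noting above is vmap(grad(compute_loss), in_dims=(None, 0, 0)): grad turns compute_loss into a function that returns gradients for a single (image, target) pair, and vmap maps it over the batch dimension while sharing the weights (in_dim None), yielding per-sample gradients. A minimal self-contained sketch of that recipe, assuming the same make_functional calling convention as the test:

import torch
import torch.nn as nn
from functorch import grad, vmap, make_functional

weights, func_model, _ = make_functional(nn.Linear(3, 1))

def compute_loss(weights, x, y):
    # Lift the single sample to a batch of one, as the test does.
    out = func_model(weights, (x.unsqueeze(0),))
    return ((out.squeeze(0) - y) ** 2).sum()

xs, ys = torch.randn(4, 3), torch.randn(4, 1)
# Share weights across samples (None); map over dim 0 of xs and ys.
per_sample_grads = vmap(grad(compute_loss), in_dims=(None, 0, 0))(weights, xs, ys)
# Each tensor in per_sample_grads has a leading batch dimension of 4.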
 
 instantiate_device_type_tests(
     TestGradTransform,