[functorch] testing, version.txt

Richard Zou
2021-04-27 13:09:12 -07:00
committed by Jon Janzen
parent 8277d74e42
commit 20fac9da6e
3 changed files with 175 additions and 1 deletion


@@ -3,3 +3,5 @@ dist/
functorch.egg-info/
*__pycache__*
functorch/version.py
functorch/_C.so
.gdbinit


@@ -12,7 +12,10 @@ import types
from functools import partial
import functorch
from functorch import grad, vjp, vmap, make_functional, jacrev
from functorch import grad, vjp, vmap, make_functional, jacrev, make_functional_with_buffers
# NB: numpy is a testing dependency!
import numpy as np
class TestGradTransform(TestCase):
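The newly imported make_functional_with_buffers, like the already-imported make_functional, splits an nn.Module into a parameter list (plus buffers, for the _with_buffers variant) and a pure function, which is what lets the tests below differentiate and vmap over plain parameter lists. The following is a minimal standalone sketch of the calling convention these tests rely on; it is illustrative only (the toy model and data are made up, and the unused return values are discarded just as in the tests), not part of this commit:

import torch
import torch.nn as nn
from functorch import grad, make_functional

device = 'cpu'
x, y = torch.randn(4, 3, device=device), torch.randn(4, 3, device=device)

# Split the module into a parameter list and a stateless function.
weights, func, _ = make_functional(nn.Linear(3, 3).to(device))

def compute_loss(weights, x, y):
    out = func(weights, (x,))  # the prototype passes forward args as a tuple
    return ((out - y) ** 2).mean()

# grad differentiates with respect to the first argument (the parameter list).
grads = grad(compute_loss)(weights, x, y)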
@@ -451,6 +454,168 @@ class TestComposability(TestCase):
        y = vjp_fn(x)[0]
        # Honestly IDK what the result here is... but at least it runs


class TestExamplesCorrectness(TestCase):
    def test_maml_regression(self, device):
        class ThreeLayerNet(nn.Module):
            def __init__(self):
                super(ThreeLayerNet, self).__init__()
                self.fc1 = nn.Linear(1, 40)
                self.relu1 = nn.ReLU()
                self.fc2 = nn.Linear(40, 40)
                self.relu2 = nn.ReLU()
                self.fc3 = nn.Linear(40, 1)

            def forward(self, x):
                x = self.fc1(x)
                x = self.relu1(x)
                x = self.fc2(x)
                x = self.relu2(x)
                x = self.fc3(x)
                return x

        # The prototype doesn't like F.mse_loss.
        def mse_loss(x, y):
            return torch.mean((x - y) ** 2)

        params, net, _ = make_functional(ThreeLayerNet().to(device))
        K = 20
        losses = []
        num_tasks = 4
        alpha = 0.1

        def sample_tasks(outer_batch_size, inner_batch_size):
            # Select amplitude and phase for the task
            As = []
            phases = []
            for _ in range(outer_batch_size):
                As.append(np.random.uniform(low=0.1, high=.5))
                phases.append(np.random.uniform(low=0., high=np.pi))

            def get_batch():
                xs, ys = [], []
                for A, phase in zip(As, phases):
                    x = np.random.uniform(low=-5., high=5., size=(inner_batch_size, 1))
                    y = A * np.sin(x + phase)
                    xs.append(x)
                    ys.append(y)
                return torch.tensor(xs, dtype=torch.float, device=device), \
                    torch.tensor(ys, dtype=torch.float, device=device)

            x1, y1 = get_batch()
            x2, y2 = get_batch()
            return x1, y1, x2, y2

        def get_loss_for_task(use_transform, x1, y1, x2, y2):
            def inner_loss(params, x1, y1):
                f = net(params, (x1,))
                loss = mse_loss(f, y1)
                return loss

            if use_transform:
                grads = grad(inner_loss)(params, x1, y1)
            else:
                loss = inner_loss(params, x1, y1)
                grads = torch.autograd.grad(loss, params, create_graph=True)
            new_params = [(params[i] - alpha * grads[i]) for i in range(len(params))]

            v_f = net(new_params, (x2,))
            return mse_loss(v_f, y2)

        task = sample_tasks(num_tasks, K)

        # Compute with vmap+grad
        inner_losses = vmap(partial(get_loss_for_task, True))(
            task[0], task[1], task[2], task[3])
        loss2 = sum(inner_losses) / len(inner_losses)
        result_grads = torch.autograd.grad(loss2, params)

        # Compute without vmap+grad
        inner_losses = [
            get_loss_for_task(False, task[0][i], task[1][i], task[2][i], task[3][i])
            for i in range(num_tasks)
        ]
        loss2 = sum(inner_losses) / len(inner_losses)
        expected_grads = torch.autograd.grad(loss2, params)

        self.assertEqual(result_grads, expected_grads)
    def test_maml_omniglot(self, device):
        # TODO: there appear to be precision issues for float32
        dtype = torch.double

        # TODO: The prototype doesn't support in-place relu (and some other
        # in-place operations); that can be fixed.
        inplace_relu = False

        n_way = 5
        n_inner_iter = 2
        num_tasks = 2

        class Flatten(nn.Module):
            def forward(self, input):
                return input.view(input.size(0), -1)

        net = nn.Sequential(
            nn.Conv2d(1, 64, 3),
            nn.BatchNorm2d(64, momentum=1, affine=True),
            nn.ReLU(inplace=inplace_relu),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(64, 64, 3),
            nn.BatchNorm2d(64, momentum=1, affine=True),
            nn.ReLU(inplace=inplace_relu),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(64, 64, 3),
            nn.BatchNorm2d(64, momentum=1, affine=True),
            nn.ReLU(inplace=inplace_relu),
            nn.MaxPool2d(2, 2),
            Flatten(),
            nn.Linear(64, n_way)).to(device).to(dtype)

        params, buffers, fnet, _, _ = make_functional_with_buffers(net)
        net = (params, buffers, fnet)

        def loss_for_task(net, n_inner_iter, use_transform, x_spt, y_spt, x_qry, y_qry):
            params, buffers, fnet = net
            querysz = x_qry.size(0)

            def compute_loss(new_params, buffers, x, y):
                logits = fnet(new_params, buffers, (x,))
                loss = F.cross_entropy(logits, y)
                return loss

            new_params = params
            for _ in range(n_inner_iter):
                if use_transform:
                    grads = grad(compute_loss)(new_params, buffers, x_spt, y_spt)
                else:
                    res = compute_loss(new_params, buffers, x_spt, y_spt)
                    grads = torch.autograd.grad(res, new_params, create_graph=True)
                new_params = [p - g * 1e-1 for p, g in zip(new_params, grads)]

            qry_logits = fnet(new_params, buffers, (x_qry,))
            qry_loss = F.cross_entropy(qry_logits, y_qry)
            qry_acc = (qry_logits.argmax(dim=1) == y_qry).sum() / querysz
            return qry_loss, qry_acc

        # Get some sample inputs...
        x_spt = torch.randn(num_tasks, 25, 1, 28, 28, dtype=dtype, device=device)
        y_spt = torch.randint(0, 5, (num_tasks, 25), device=device)
        x_qry = torch.randn(num_tasks, 75, 1, 28, 28, dtype=dtype, device=device)
        y_qry = torch.randint(0, 5, (num_tasks, 75), device=device)

        # Compute with vmap + grad
        compute_loss = partial(loss_for_task, net, n_inner_iter, True)
        qry_losses, _ = vmap(compute_loss)(x_spt, y_spt, x_qry, y_qry)
        result_grads = torch.autograd.grad(qry_losses.sum(), params)

        # Compute without vmap + grad
        compute_loss = partial(loss_for_task, net, n_inner_iter, False)
        losses = [compute_loss(x_spt[i], y_spt[i], x_qry[i], y_qry[i])[0]
                  for i in range(num_tasks)]
        expected_grads = torch.autograd.grad(sum(losses), params)

        self.assertEqual(result_grads, expected_grads)


instantiate_device_type_tests(
    TestGradTransform,
    globals(),
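Both new tests check the same property: composing functorch's vmap and grad over a batch of tasks must match an explicit Python loop over tasks combined with torch.autograd.grad. The following is a minimal standalone sketch of that equivalence; it is illustrative only and not part of this commit (the toy per_task_loss, parameters, and data are made up):

import torch
from functorch import grad, vmap

w = torch.randn(3, requires_grad=True)           # "parameters", shared across tasks
xs, ys = torch.randn(4, 3), torch.randn(4, 3)    # leading dimension of 4 = tasks

def per_task_loss(x, y):
    # One inner gradient step on w via functorch's grad, then the
    # post-adaptation loss, mirroring the MAML tests above.
    def inner(w):
        return ((x * w - y) ** 2).mean()
    w2 = w - 0.1 * grad(inner)(w)
    return ((x * w2 - y) ** 2).mean()

batched = vmap(per_task_loss)(xs, ys)                                   # vmap over the task dim
looped = torch.stack([per_task_loss(xs[i], ys[i]) for i in range(4)])   # explicit Python loop
# The tests assert that the gradients of the averaged loss with respect to the
# original parameters agree between the two paths.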
@@ -471,6 +636,12 @@ instantiate_device_type_tests(
    globals(),
    None,
)

instantiate_device_type_tests(
    TestExamplesCorrectness,
    globals(),
    None,
)

if __name__ == '__main__':
    run_tests()
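Like the existing test classes, the new TestExamplesCorrectness is device-generic: instantiate_device_type_tests expands it into per-device classes, which is why every test method takes a device argument. A minimal standalone sketch of that pattern, assuming the usual PyTorch test helpers are importable (illustrative, not part of this commit):

import torch
from torch.testing._internal.common_utils import TestCase, run_tests
from torch.testing._internal.common_device_type import instantiate_device_type_tests

class TestExample(TestCase):
    def test_add(self, device):
        x = torch.ones(2, device=device)
        self.assertEqual(x + x, 2 * x)

# Emits concrete classes such as TestExampleCPU (and TestExampleCUDA when a GPU
# is available), with device-suffixed test names like test_add_cpu.
instantiate_device_type_tests(TestExample, globals())

if __name__ == '__main__':
    run_tests()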

functorch/version.txt Normal file

@@ -0,0 +1 @@
0.0.1a0