diff --git a/functorch/.gitignore b/functorch/.gitignore
index 9f5c3b244949..90ca585c3bc3 100644
--- a/functorch/.gitignore
+++ b/functorch/.gitignore
@@ -3,3 +3,5 @@ dist/
 functorch.egg-info/
 *__pycache__*
 functorch/version.py
+functorch/_C.so
+.gdbinit
diff --git a/functorch/test/test_eager_transforms.py b/functorch/test/test_eager_transforms.py
index 2883214fefb9..56eecfb06728 100644
--- a/functorch/test/test_eager_transforms.py
+++ b/functorch/test/test_eager_transforms.py
@@ -12,7 +12,10 @@ import types
 from functools import partial
 
 import functorch
-from functorch import grad, vjp, vmap, make_functional, jacrev
+from functorch import grad, vjp, vmap, make_functional, jacrev, make_functional_with_buffers
+
+# NB: numpy is a testing dependency!
+import numpy as np
 
 
 class TestGradTransform(TestCase):
@@ -451,6 +454,168 @@ class TestComposability(TestCase):
         y = vjp_fn(x)[0]
         # Honestly IDK what the result here is... but at least it runs
 
+
+class TestExamplesCorrectness(TestCase):
+    def test_maml_regression(self, device):
+        class ThreeLayerNet(nn.Module):
+            def __init__(self):
+                super(ThreeLayerNet, self).__init__()
+                self.fc1 = nn.Linear(1, 40)
+                self.relu1 = nn.ReLU()
+                self.fc2 = nn.Linear(40, 40)
+                self.relu2 = nn.ReLU()
+                self.fc3 = nn.Linear(40, 1)
+
+            def forward(self, x):
+                x = self.fc1(x)
+                x = self.relu1(x)
+                x = self.fc2(x)
+                x = self.relu2(x)
+                x = self.fc3(x)
+                return x
+
+        # The prototype doesn't like F.mse_loss.
+        def mse_loss(x, y):
+            return torch.mean((x - y) ** 2)
+
+        params, net, _ = make_functional(ThreeLayerNet().to(device))
+        K = 20
+        losses = []
+        num_tasks = 4
+        alpha = 0.1
+
+        def sample_tasks(outer_batch_size, inner_batch_size):
+            # Select amplitude and phase for the task
+            As = []
+            phases = []
+            for _ in range(outer_batch_size):
+                As.append(np.random.uniform(low=0.1, high=.5))
+                phases.append(np.random.uniform(low=0., high=np.pi))
+            def get_batch():
+                xs, ys = [], []
+                for A, phase in zip(As, phases):
+                    x = np.random.uniform(low=-5., high=5., size=(inner_batch_size, 1))
+                    y = A * np.sin(x + phase)
+                    xs.append(x)
+                    ys.append(y)
+                return torch.tensor(xs, dtype=torch.float, device=device), \
+                    torch.tensor(ys, dtype=torch.float, device=device)
+            x1, y1 = get_batch()
+            x2, y2 = get_batch()
+            return x1, y1, x2, y2
+
+        def get_loss_for_task(use_transform, x1, y1, x2, y2):
+            def inner_loss(params, x1, y1):
+                f = net(params, (x1,))
+                loss = mse_loss(f, y1)
+                return loss
+
+            if use_transform:
+                grads = grad(inner_loss)(params, x1, y1)
+            else:
+                loss = inner_loss(params, x1, y1)
+                grads = torch.autograd.grad(loss, params, create_graph=True)
+            new_params = [(params[i] - alpha*grads[i]) for i in range(len(params))]
+
+            v_f = net(new_params, (x2,))
+            return mse_loss(v_f, y2)
+
+        task = sample_tasks(num_tasks, K)
+
+        # Compute with vmap+grad
+        inner_losses = vmap(partial(get_loss_for_task, True))\
+            (task[0], task[1], task[2], task[3])
+        loss2 = sum(inner_losses)/len(inner_losses)
+        result_grads = torch.autograd.grad(loss2, params)
+
+        # Compute without vmap+grad
+        inner_losses = [
+            get_loss_for_task(False, task[0][i], task[1][i], task[2][i], task[3][i])
+            for i in range(num_tasks)
+        ]
+        loss2 = sum(inner_losses)/len(inner_losses)
+        expected_grads = torch.autograd.grad(loss2, params)
+
+        self.assertEqual(result_grads, expected_grads)
+
+    def test_maml_omniglot(self, device):
+        # TODO: there appears to be precision issues for float32
+        dtype = torch.double
+
+        # TODO: The prototype doesn't support in-place relu (and some other
+        # in-place operations. That can be fixed.)
+        inplace_relu = False
+        n_way = 5
+        n_inner_iter = 2
+        num_tasks = 2
+        class Flatten(nn.Module):
+            def forward(self, input):
+                return input.view(input.size(0), -1)
+
+        net = nn.Sequential(
+            nn.Conv2d(1, 64, 3),
+            nn.BatchNorm2d(64, momentum=1, affine=True),
+            nn.ReLU(inplace=inplace_relu),
+            nn.MaxPool2d(2, 2),
+            nn.Conv2d(64, 64, 3),
+            nn.BatchNorm2d(64, momentum=1, affine=True),
+            nn.ReLU(inplace=inplace_relu),
+            nn.MaxPool2d(2, 2),
+            nn.Conv2d(64, 64, 3),
+            nn.BatchNorm2d(64, momentum=1, affine=True),
+            nn.ReLU(inplace=inplace_relu),
+            nn.MaxPool2d(2, 2),
+            Flatten(),
+            nn.Linear(64, n_way)).to(device).to(dtype)
+
+        params, buffers, fnet, _, _, = make_functional_with_buffers(net)
+        net = (params, buffers, fnet)
+
+        def loss_for_task(net, n_inner_iter, use_transform, x_spt, y_spt, x_qry, y_qry):
+            params, buffers, fnet = net
+            querysz = x_qry.size(0)
+
+            def compute_loss(new_params, buffers, x, y):
+                logits = fnet(new_params, buffers, (x,))
+                loss = F.cross_entropy(logits, y)
+                return loss
+
+            new_params = params
+            for _ in range(n_inner_iter):
+                if use_transform:
+                    grads = grad(compute_loss)(new_params, buffers, x_spt, y_spt)
+                else:
+                    res = compute_loss(new_params, buffers, x_spt, y_spt)
+                    grads = torch.autograd.grad(res, new_params, create_graph=True)
+                new_params = [p - g * 1e-1 for p, g, in zip(new_params, grads)]
+
+            qry_logits = fnet(new_params, buffers, (x_qry,))
+            qry_loss = F.cross_entropy(qry_logits, y_qry)
+            qry_acc = (qry_logits.argmax(
+                dim=1) == y_qry).sum() / querysz
+
+            return qry_loss, qry_acc
+
+        # Get some sample inputs...
+        x_spt = torch.randn(num_tasks, 25, 1, 28, 28, dtype=dtype, device=device)
+        y_spt = torch.randint(0, 5, (num_tasks, 25), device=device)
+        x_qry = torch.randn(num_tasks, 75, 1, 28, 28, dtype=dtype, device=device)
+        y_qry = torch.randint(0, 5, (num_tasks, 75), device=device)
+
+        # compute with vmap + grad
+        compute_loss = partial(loss_for_task, net, n_inner_iter, True)
+        qry_losses, _ = vmap(compute_loss)(x_spt, y_spt, x_qry, y_qry)
+        result_grads = torch.autograd.grad(qry_losses.sum(), params)
+
+        # compute without vmap + grad
+        compute_loss = partial(loss_for_task, net, n_inner_iter, False)
+        losses = [compute_loss(x_spt[i], y_spt[i], x_qry[i], y_qry[i])[0]
+                  for i in range(num_tasks)]
+        expected_grads = torch.autograd.grad(sum(losses), params)
+
+        self.assertEqual(result_grads, expected_grads)
+
+
 instantiate_device_type_tests(
     TestGradTransform,
     globals(),
@@ -471,6 +636,12 @@ instantiate_device_type_tests(
     globals(),
     None,
 )
+instantiate_device_type_tests(
+    TestExamplesCorrectness,
+    globals(),
+    None,
+)
+
 
 
 if __name__ == '__main__':
diff --git a/functorch/version.txt b/functorch/version.txt
new file mode 100644
index 000000000000..8ea4f48f6ca5
--- /dev/null
+++ b/functorch/version.txt
@@ -0,0 +1 @@
+0.0.1a0