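# MAML-style meta-learning on 1-D sine-wave regression: a small MLP is
# meta-trained over randomly sampled sine tasks by differentiating through a
# single inner SGD step, then adapted to a fresh task and plotted.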
import math

import numpy as np
import torch
from torch.nn import functional as F

import matplotlib as mpl

mpl.use('Agg')
import matplotlib.pyplot as plt


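# A plain 3-layer MLP (1 -> 40 -> 40 -> 1). Parameters live in an explicit
# list so that adapted ("fast") parameters can be swapped in for a forward
# pass without modifying the originals.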
def net(x, params):
    x = F.linear(x, params[0], params[1])
    x = F.relu(x)

    x = F.linear(x, params[2], params[3])
    x = F.relu(x)

    x = F.linear(x, params[4], params[5])
    return x


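# Meta-parameters: weights and biases of the three linear layers, initialized
# by hand with requires_grad set so the meta-optimizer can update them.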
params = [
    torch.Tensor(40, 1).uniform_(-1., 1.).requires_grad_(),
    torch.Tensor(40).zero_().requires_grad_(),

    torch.Tensor(40, 40).uniform_(-1. / math.sqrt(40), 1. / math.sqrt(40)).requires_grad_(),
    torch.Tensor(40).zero_().requires_grad_(),

    torch.Tensor(1, 40).uniform_(-1. / math.sqrt(40), 1. / math.sqrt(40)).requires_grad_(),
    torch.Tensor(1).zero_().requires_grad_(),
]

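# Adam updates the meta-parameters; alpha is the inner-loop SGD step size,
# K the number of points per task batch, and num_tasks the meta-batch size.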
opt = torch.optim.Adam(params, lr=1e-3)
alpha = 0.1

K = 20
losses = []
num_tasks = 4


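# Sample a meta-batch of sine tasks. Each task is a sinusoid with a random
# amplitude and phase; two independent batches of points are drawn per task,
# one for the inner adaptation step and one for the outer (meta) loss.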
def sample_tasks(outer_batch_size, inner_batch_size):
    # Select amplitude and phase for the task
    As = []
    phases = []
    for _ in range(outer_batch_size):
        As.append(np.random.uniform(low=0.1, high=.5))
        phases.append(np.random.uniform(low=0., high=np.pi))

    def get_batch():
        xs, ys = [], []
        for A, phase in zip(As, phases):
            x = np.random.uniform(low=-5., high=5., size=(inner_batch_size, 1))
            y = A * np.sin(x + phase)
            xs.append(x)
            ys.append(y)
        return torch.tensor(xs, dtype=torch.float), torch.tensor(ys, dtype=torch.float)

    x1, y1 = get_batch()
    x2, y2 = get_batch()
    return x1, y1, x2, y2


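# Outer (meta) training loop.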
for it in range(20000):
    loss2 = 0.0
    opt.zero_grad()

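    # Per-task loss: take one SGD step on (x1, y1) starting from the current
    # meta-parameters, then evaluate the adapted parameters on (x2, y2).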
    def get_loss_for_task(x1, y1, x2, y2):
        f = net(x1, params)
        loss = F.mse_loss(f, y1)

        # create_graph=True because computing grads here is part of the forward pass.
        # We want to differentiate through the SGD update steps and get higher order
        # derivatives in the backward pass.
        grads = torch.autograd.grad(loss, params, create_graph=True)
        new_params = [(params[i] - alpha * grads[i]) for i in range(len(params))]

        v_f = net(x2, new_params)
        return F.mse_loss(v_f, y2)

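    # Meta step: average the post-adaptation losses over the task batch and
    # backpropagate through the inner updates into the meta-parameters.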
    task = sample_tasks(num_tasks, K)
    inner_losses = [
        get_loss_for_task(task[0][i], task[1][i], task[2][i], task[3][i])
        for i in range(num_tasks)
    ]
    loss2 = sum(inner_losses) / len(inner_losses)
    loss2.backward()

    opt.step()

    if it % 100 == 0:
        print('Iteration %d -- Outer Loss: %.4f' % (it, loss2))
    losses.append(loss2.item())  # store a plain float so the curve can be plotted later

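# Test-time adaptation: sample one new sine task and a handful of examples
# from it.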
t_A = torch.tensor(0.0).uniform_(0.1, 0.5)
t_b = torch.tensor(0.0).uniform_(0.0, math.pi)

t_x = torch.empty(4, 1).uniform_(-5, 5)
t_y = t_A * torch.sin(t_x + t_b)

opt.zero_grad()

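# Adapt the meta-learned parameters to the new task with a few gradient steps
# on an L1 loss.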
t_params = params
for k in range(5):
    t_f = net(t_x, t_params)
    t_loss = F.l1_loss(t_f, t_y)

    grads = torch.autograd.grad(t_loss, t_params, create_graph=True)
    t_params = [(t_params[i] - alpha * grads[i]) for i in range(len(params))]

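# Evaluate the adapted model on a dense grid and plot the fit, the target
# sinusoid, and the adaptation examples; also plot a 20-step moving average
# of the outer loss.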
test_x = torch.arange(-2 * math.pi, 2 * math.pi, step=0.01).unsqueeze(1)
test_y = t_A * torch.sin(test_x + t_b)

test_f = net(test_x, t_params)

plt.plot(test_x.data.numpy(), test_y.data.numpy(), label='sin(x)')
plt.plot(test_x.data.numpy(), test_f.data.numpy(), label='net(x)')
plt.plot(t_x.data.numpy(), t_y.data.numpy(), 'o', label='Examples')
plt.legend()
plt.savefig('maml-sine.png')

plt.figure()
plt.plot(np.convolve(losses, [.05] * 20))
plt.savefig('losses.png')