import argparse
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.func import functional_call, grad_and_value, stack_module_state, vmap

# Adapted from http://willwhitney.com/parallel-training-jax.html , which is a
# tutorial on Model Ensembling with JAX by Will Whitney.
#
# The original code comes with the following citation:
# @misc{Whitney2021Parallelizing,
#     author = {William F. Whitney},
#     title = { {Parallelizing neural networks on one GPU with JAX} },
#     year = {2021},
#     url = {http://willwhitney.com/parallel-training-jax.html},
# }

# GOAL: Demonstrate that it is possible to use eager-mode vmap
# to parallelize training over models.

parser = argparse.ArgumentParser(description="Functorch Ensembled Models")
parser.add_argument(
    "--device",
    type=str,
    default="cpu",
    help="CPU or GPU ID for this process (default: 'cpu')",
)
args = parser.parse_args()

DEVICE = args.device

# Step 1: Make some spirals


def make_spirals(n_samples, noise_std=0.0, rotations=1.0):
    ts = torch.linspace(0, 1, n_samples, device=DEVICE)
    rs = ts**0.5
    thetas = rs * rotations * 2 * math.pi
    signs = torch.randint(0, 2, (n_samples,), device=DEVICE) * 2 - 1
    labels = (signs > 0).to(torch.long).to(DEVICE)

    xs = (
        rs * signs * torch.cos(thetas)
        + torch.randn(n_samples, device=DEVICE) * noise_std
    )
    ys = (
        rs * signs * torch.sin(thetas)
        + torch.randn(n_samples, device=DEVICE) * noise_std
    )
    points = torch.stack([xs, ys], dim=1)
    return points, labels


points, labels = make_spirals(100, noise_std=0.05)


# Step 2: Define two-layer MLP and loss function
class MLPClassifier(nn.Module):
    def __init__(self, hidden_dim=32, n_classes=2):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.n_classes = n_classes

        self.fc1 = nn.Linear(2, self.hidden_dim)
        self.fc2 = nn.Linear(self.hidden_dim, self.n_classes)

    def forward(self, x):
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        x = F.log_softmax(x, -1)
        return x


loss_fn = nn.NLLLoss()
model = MLPClassifier().to(DEVICE)


# Step 3: Define a training step. functional_call runs the model with an
# explicit dict of weights, so the step takes weights in and returns updated
# weights out instead of mutating the module.
def train_step_fn(weights, batch, targets, lr=0.2):
    def compute_loss(weights, batch, targets):
        output = functional_call(model, weights, batch)
        loss = loss_fn(output, targets)
        return loss

    grad_weights, loss = grad_and_value(compute_loss)(weights, batch, targets)

    # NB: PyTorch is missing a "functional optimizer API" (possibly coming soon)
    # so we are going to re-implement SGD here.
    new_weights = {}
    with torch.no_grad():
        for key in grad_weights:
            new_weights[key] = weights[key] - grad_weights[key] * lr

    return loss, new_weights


# Step 4: Let's verify this actually trains.
# We should see the loss decrease.
def step4():
    global weights
    # Start from the model's initial parameters, then feed the updated weights
    # back in on every iteration so the loss actually decreases.
    weights = dict(model.named_parameters())
    for i in range(2000):
        loss, weights = train_step_fn(weights, points, labels)
        if i % 100 == 0:
            print(loss)


step4()

# Step 5: We're ready for multiple models. Let's define an init_fn
# that, given a number of models, returns to us all of the weights.


def init_fn(num_models):
    models = [MLPClassifier().to(DEVICE) for _ in range(num_models)]
    params, _ = stack_module_state(models)
    return params
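

# For illustration (this quick check is an addition, not part of the original
# example): stack_module_state stacks each parameter along a new leading
# "model" dimension, which is exactly the dimension vmap maps over below.
# With two MLPClassifiers, fc1.weight goes from per-model shape (32, 2) to a
# single stacked tensor of shape (2, 32, 2).
_stacked_example = init_fn(num_models=2)
for _name, _param in _stacked_example.items():
    print(_name, tuple(_param.shape))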


# Step 6: Now, can we try multiple models at the same time?
# The answer is: yes! `loss` now holds one value per model (a length-2 tensor
# here), and we can see that both values keep on decreasing.


def step6():
    parallel_train_step_fn = vmap(train_step_fn, in_dims=(0, None, None))
    batched_weights = init_fn(num_models=2)
    for i in range(2000):
        loss, batched_weights = parallel_train_step_fn(batched_weights, points, labels)
        if i % 200 == 0:
            print(loss)


step6()

# Step 7: Now, the flaw with step 6 is that we were training on the same exact
# data. This can lead to all of the models in the ensemble overfitting in the
# same way. The solution that http://willwhitney.com/parallel-training-jax.html
# applies is to randomly subset the data in a way that the models do not receive
# exactly the same data in each training step!
# Because the goal of this doc is to show that we can use eager-mode vmap to
# achieve similar things as JAX, the rest of this is left as an exercise to the
# reader; one possible sketch follows below.
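

# One possible sketch of the Step 7 exercise. This is an assumption about how
# it could be done, not the original tutorial's code: sample a different random
# minibatch of the spiral data for each model outside of vmap, then map over
# the leading "model" dimension of the data as well via in_dims=(0, 0, 0).
# The name `step7` and the batch size are illustrative choices, and the
# function is intentionally left uncalled since Step 7 is an exercise.
def step7(num_models=2, batch_size=32, num_steps=2000):
    parallel_train_step_fn = vmap(train_step_fn, in_dims=(0, 0, 0))
    batched_weights = init_fn(num_models=num_models)
    for i in range(num_steps):
        # (num_models, batch_size) indices -> a different data subset per model.
        idx = torch.randint(
            0, points.shape[0], (num_models, batch_size), device=DEVICE
        )
        minibatch_points = points[idx]  # (num_models, batch_size, 2)
        minibatch_labels = labels[idx]  # (num_models, batch_size)
        loss, batched_weights = parallel_train_step_fn(
            batched_weights, minibatch_points, minibatch_labels
        )
        if i % 200 == 0:
            print(loss)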


# In conclusion, to achieve what http://willwhitney.com/parallel-training-jax.html
# does, we used the following additional items that PyTorch does not have:
# 1. NN module functional API that turns a module into a (state, state_less_fn) pair
# 2. Functional optimizers
# 3. A "functional" grad API (that effectively wraps autograd.grad)
# 4. Composability between the functional grad API and torch.vmap (a tiny
#    standalone illustration of this property follows below).
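

# A tiny standalone illustration of item 4 (added as an example, using only
# APIs already imported above): grad composes with vmap, which is the property
# the ensembled training relies on. Differentiating sin and mapping over a
# batch of inputs gives cos, elementwise.
_xs = torch.linspace(-1, 1, 5, device=DEVICE)
print(vmap(torch.func.grad(torch.sin))(_xs))  # matches torch.cos(_xs)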