import argparse
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.func import functional_call, grad_and_value, stack_module_state, vmap

# Adapted from http://willwhitney.com/parallel-training-jax.html , which is a
# tutorial on Model Ensembling with JAX by Will Whitney.
#
# The original code comes with the following citation:
# @misc{Whitney2021Parallelizing,
#     author = {William F. Whitney},
#     title = { {Parallelizing neural networks on one GPU with JAX} },
#     year = {2021},
#     url = {http://willwhitney.com/parallel-training-jax.html},
# }

# GOAL: Demonstrate that it is possible to use eager-mode vmap
# to parallelize training over models.
parser = argparse.ArgumentParser(description="Functorch Ensembled Models")
parser.add_argument(
    "--device",
    type=str,
    default="cpu",
    help="CPU or GPU ID for this process (default: 'cpu')",
)
args = parser.parse_args()

DEVICE = args.device


# Step 1: Make some spirals
def make_spirals(n_samples, noise_std=0.0, rotations=1.0):
    ts = torch.linspace(0, 1, n_samples, device=DEVICE)
    rs = ts**0.5
    thetas = rs * rotations * 2 * math.pi
    signs = torch.randint(0, 2, (n_samples,), device=DEVICE) * 2 - 1
    labels = (signs > 0).to(torch.long).to(DEVICE)

    xs = (
        rs * signs * torch.cos(thetas)
        + torch.randn(n_samples, device=DEVICE) * noise_std
    )
    ys = (
        rs * signs * torch.sin(thetas)
        + torch.randn(n_samples, device=DEVICE) * noise_std
    )
    points = torch.stack([xs, ys], dim=1)
    return points, labels


points, labels = make_spirals(100, noise_std=0.05)


# Step 2: Define a two-layer MLP and loss function
class MLPClassifier(nn.Module):
    def __init__(self, hidden_dim=32, n_classes=2):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.n_classes = n_classes

        self.fc1 = nn.Linear(2, self.hidden_dim)
        self.fc2 = nn.Linear(self.hidden_dim, self.n_classes)

    def forward(self, x):
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        x = F.log_softmax(x, -1)
        return x


loss_fn = nn.NLLLoss()
model = MLPClassifier().to(DEVICE)


# Step 3: Define a functional training step. `functional_call` runs the model
# with an explicit dict of weights, and `grad_and_value` computes the gradients
# of the loss with respect to those weights.
def train_step_fn(weights, batch, targets, lr=0.2):
    def compute_loss(weights, batch, targets):
        output = functional_call(model, weights, batch)
        loss = loss_fn(output, targets)
        return loss

    grad_weights, loss = grad_and_value(compute_loss)(weights, batch, targets)

    # NB: PyTorch is missing a "functional optimizer API" (possibly coming soon)
    # so we are going to re-implement SGD here.
    new_weights = {}
    with torch.no_grad():
        for key in grad_weights:
            new_weights[key] = weights[key] - grad_weights[key] * lr

    return loss, new_weights


# Step 4: Let's verify this actually trains.
# We should see the loss decrease.
def step4():
    global weights
    # Start from the model's initial parameters and thread the updated weights
    # through each iteration so the model actually trains.
    weights = dict(model.named_parameters())
    for i in range(2000):
        loss, weights = train_step_fn(weights, points, labels)
        if i % 100 == 0:
            print(loss)


step4()


# Step 5: We're ready for multiple models. Let's define an init_fn
# that, given a number of models, returns to us all of the weights.
def init_fn(num_models):
    models = [MLPClassifier().to(DEVICE) for _ in range(num_models)]
    params, _ = stack_module_state(models)
    return params


# Step 6: Now, can we train multiple models at the same time?
# The answer is: yes! `loss` now holds one value per model, and we can see
# that each of them keeps decreasing.
def step6():
    parallel_train_step_fn = vmap(train_step_fn, in_dims=(0, None, None))
    batched_weights = init_fn(num_models=2)
    for i in range(2000):
        loss, batched_weights = parallel_train_step_fn(batched_weights, points, labels)
        if i % 200 == 0:
            print(loss)


step6()
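
# Not part of the original tutorial: a minimal sketch showing that the same
# vmap + functional_call combination also gives us batched *predictions* from
# the ensemble. `ensemble_accuracy` is a hypothetical helper; it expects
# stacked weights like the ones produced in step 6.
def ensemble_accuracy(batched_weights, batch, targets):
    def predict(weights, batch):
        return functional_call(model, weights, batch)

    # Log-probabilities with shape (num_models, n_samples, n_classes).
    batched_log_probs = vmap(predict, in_dims=(0, None))(batched_weights, batch)
    preds = batched_log_probs.argmax(dim=-1)
    # Per-model accuracy, shape (num_models,).
    return (preds == targets).float().mean(dim=-1)
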
# Step 7: Now, the flaw with step 6 is that we were training on the exact same
# data. This can lead to all of the models in the ensemble overfitting in the
# same way. The solution that http://willwhitney.com/parallel-training-jax.html
# applies is to randomly subset the data so that the models do not receive
# exactly the same data in each training step!
# Because the goal of this doc is to show that we can use eager-mode vmap to
# achieve the same kind of thing as JAX, the rest is left as an exercise to the
# reader (a minimal sketch of one possible approach is given at the end of this
# file).

# In conclusion, to achieve what http://willwhitney.com/parallel-training-jax.html
# does, we used the following additional items on top of plain PyTorch modules:
# 1. An NN module functional API that runs a module from an explicit set of
#    weights (torch.func.functional_call / stack_module_state)
# 2. Functional optimizers (PyTorch does not ship these yet, so SGD is
#    re-implemented by hand above)
# 3. A "functional" grad API that effectively wraps autograd.grad
#    (torch.func.grad_and_value)
# 4. Composability between the functional grad API and torch.vmap.
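
# A minimal sketch of the exercise above (one possible way to finish it, not
# part of the original tutorial): give every model its own random minibatch at
# each step by sampling a per-model batch of indices and vmap-ing over both the
# stacked weights and the indices. `step7`, `batch_size`, and
# `sampled_train_step_fn` are names introduced here for illustration.
def step7(num_models=2, batch_size=32, num_steps=2000):
    def sampled_train_step_fn(weights, indices, points, labels):
        # Each model sees only the rows selected by its own indices.
        return train_step_fn(weights, points[indices], labels[indices])

    parallel_train_step_fn = vmap(sampled_train_step_fn, in_dims=(0, 0, None, None))
    batched_weights = init_fn(num_models)
    for i in range(num_steps):
        # An independent random minibatch of indices for every model.
        indices = torch.randint(
            0, points.shape[0], (num_models, batch_size), device=DEVICE
        )
        loss, batched_weights = parallel_train_step_fn(
            batched_weights, indices, points, labels
        )
        if i % 200 == 0:
            print(loss)


# step7()  # uncomment to run the sketch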