[functorch] Quick attempt at hiding functional module init

Introduces `functional_init` and `functional_init_with_buffers`, which
let one initialize an ensemble of modules more easily than before. This
was done in the spirit of make_functional: the API still looks awkward,
especially when buffers are involved.
Richard Zou
2021-05-17 14:37:33 -07:00
committed by Jon Janzen
parent 7f344c5a0b
commit a7f406ce58
4 changed files with 97 additions and 5 deletions
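
For orientation, a minimal usage sketch of the new API (not part of the diff; `TinyMLP` is a made-up stand-in for whatever nn.Module you want to ensemble, and the call pattern mirrors the tests added below):

import torch
import torch.nn as nn
import torch.nn.functional as F
from functorch import functional_init, vmap

# TinyMLP is a hypothetical toy module used only for this illustration.
class TinyMLP(nn.Module):
    def __init__(self, hidden_dim=8):
        super().__init__()
        self.fc1 = nn.Linear(2, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, 2)

    def forward(self, x):
        return self.fc2(F.relu(self.fc1(x)))

num_models = 5
# functional_init(ModelClass, ensemble_shape)(*ctor_args) constructs
# num_models copies, extracts their parameters, and stacks each parameter
# along a new leading ensemble dimension.
weights, fn, names = functional_init(TinyMLP, (num_models,))(8)

# One minibatch per model; vmap evaluates all models in a single call.
inputs = torch.randn(num_models, 16, 2)
outputs = vmap(fn)(weights, (inputs,))  # expected shape: [5, 16, 2]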


@@ -2,7 +2,7 @@ import math
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from functorch import make_functional, grad_and_value, vmap
+from functorch import make_functional, grad_and_value, vmap, functional_init

 # Adapted from http://willwhitney.com/parallel-training-jax.html
 # GOAL: Demonstrate that it is possible to use eager-mode vmap
@@ -84,10 +84,7 @@ step4()
 # Step 5: We're ready for multiple models. Let's define an init_fn
 # that, given a number of models, returns to us all of the weights.
 def init_fn(num_models):
-    models = tuple(MLPClassifier() for _ in range(num_models))
-    weights = tuple(make_functional(model)[0] for model in models)
-    weights = tuple(zip(*weights))
-    weights = tuple(torch.stack(shards).detach() for shards in weights)
+    weights, _, _ = functional_init(MLPClassifier, (num_models,))()
     return weights

 # Step 6: Now, can we try multiple models at the same time?


@@ -4,6 +4,7 @@ from . import _C
 from ._src.vmap import vmap
 from ._src.eager_transforms import grad, grad_and_value, vjp, jacrev
 from ._src.make_functional import make_functional, make_functional_with_buffers, load_state
+from ._src.make_functional import functional_init, functional_init_with_buffers
 from ._src.python_key import wrap_key, PythonTensor, pythonkey_trace, hasPythonKey, removePythonKey, addPythonKey, make_fx, nnc_jit, make_nnc
 from ._src.nnc_compile import nnc_compile


@@ -167,3 +167,49 @@ def make_functional_with_buffers(model: nn.Module):
         return mutable_model(*data)

     return weights, buffers, fun, weight_descriptors, buf_descriptors
+
+
+def functional_init(model_class, ensemble_shape=(), device='cpu'):
+    def wrapped(*args, **kwargs):
+        if len(ensemble_shape) >= 2:
+            raise ValueError('NYI: ensemble_shape with more than 1 element')
+        if len(ensemble_shape) == 0:
+            model = model_class(*args, **kwargs).to(device)
+            return make_functional(model)
+        num_models = ensemble_shape[0]
+        if num_models <= 0:
+            raise ValueError(f"num_models {num_models} should be > 0")
+        # NB: Not very efficient, more of a POC
+        models = tuple(model_class(*args, **kwargs).to(device)
+                       for _ in range(num_models))
+        _, fn, names = make_functional(model_class(*args, **kwargs))
+        weights = tuple(make_functional(model)[0] for model in models)
+        weights = tuple(zip(*weights))
+        weights = tuple(torch.stack(shards).detach() for shards in weights)
+        return weights, fn, names
+    return wrapped
+
+
+def functional_init_with_buffers(model_class, ensemble_shape=(), device='cpu'):
+    def wrapped(*args, **kwargs):
+        if len(ensemble_shape) >= 2:
+            raise ValueError('NYI: ensemble_shape with more than 1 element')
+        if len(ensemble_shape) == 0:
+            model = model_class(*args, **kwargs).to(device)
+            return make_functional_with_buffers(model)
+        num_models = ensemble_shape[0]
+        if num_models <= 0:
+            raise ValueError(f"num_models {num_models} should be > 0")
+        # NB: Not very efficient, more of a POC
+        models = tuple(model_class(*args, **kwargs).to(device)
+                       for _ in range(num_models))
+        _, _, fn, weight_names, buffer_names = \
+            make_functional_with_buffers(model_class(*args, **kwargs))
+        weights, buffers = zip(*tuple(make_functional_with_buffers(model)[:2]
+                                      for model in models))
+        weights = tuple(zip(*weights))
+        weights = tuple(torch.stack(shards).detach() for shards in weights)
+        buffers = tuple(zip(*buffers))
+        buffers = tuple(torch.stack(shards).detach() for shards in buffers)
+        return weights, buffers, fn, weight_names, buffer_names
+    return wrapped


@@ -17,6 +17,7 @@ import functorch
 from functorch import (
     grad, vjp, vmap, jacrev, grad_and_value,
     make_functional, make_functional_with_buffers,
+    functional_init, functional_init_with_buffers,
 )

 # NB: numpy is a testing dependency!
@@ -374,6 +375,53 @@ class TestGradTransform(TestCase):
         with self.assertRaisesRegex(RuntimeError, 'Expected pytree structure'):
             result, = vjp_fn((v1, (v2, v3)))

+    def test_functional_init(self, device):
+        class MLPClassifier(nn.Module):
+            def __init__(self, hidden_dim=32, n_classes=2):
+                super().__init__()
+                self.hidden_dim = hidden_dim
+                self.n_classes = n_classes
+
+                self.fc1 = nn.Linear(2, self.hidden_dim)
+                self.fc2 = nn.Linear(self.hidden_dim, self.n_classes)
+
+            def forward(self, x):
+                x = self.fc1(x)
+                x = F.relu(x)
+                x = self.fc2(x)
+                x = F.log_softmax(x, -1)
+                return x
+
+        B = 10
+        weights, fn, _ = functional_init(MLPClassifier, (B,))(32, 2)
+        inputs = torch.randn(B, 7, 2)
+        vmap(fn)(weights, (inputs,))
+
+    def test_functional_init_with_buffers(self, device):
+        class MLPClassifier(nn.Module):
+            def __init__(self, hidden_dim=32, n_classes=2):
+                super().__init__()
+                self.hidden_dim = hidden_dim
+                self.n_classes = n_classes
+
+                self.fc1 = nn.Linear(2, self.hidden_dim)
+                self.bn = nn.BatchNorm1d(self.hidden_dim, affine=True)
+                self.fc2 = nn.Linear(self.hidden_dim, self.n_classes)
+
+            def forward(self, x):
+                x = self.fc1(x)
+                x = F.relu(x)
+                x = self.bn(x)
+                x = self.fc2(x)
+                x = F.log_softmax(x, -1)
+                return x
+
+        B = 10
+        weights, buffers, fn, _, _ = \
+            functional_init_with_buffers(MLPClassifier, [B])(32, 2)
+        inputs = torch.randn(B, 7, 2)
+        vmap(fn)(weights, buffers, (inputs,))
+

 class TestVmapOfGrad(TestCase):
     def test_per_sample_grads_inplace_view(self, device):