[functorch] Quick attempt at hiding functional module init
Introduces `functional_init` and `functional_init_with_buffers`, which let one initialize an ensemble of modules more easily than before. This was done in the spirit of make_functional: the API still looks awkward, especially when buffers are involved.
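
For illustration, a minimal usage sketch (mirroring the test added in this commit; `MLPClassifier` is the small two-layer test module defined in the diff below, and the variable names here are just for exposition):

    import torch
    from functorch import functional_init, vmap

    B = 10  # ensemble size
    # Construct B independently initialized copies of MLPClassifier(32, 2) and
    # return their parameters stacked along a new leading dimension of size B.
    weights, fn, names = functional_init(MLPClassifier, (B,))(32, 2)

    inputs = torch.randn(B, 7, 2)           # one input batch per ensemble member
    outputs = vmap(fn)(weights, (inputs,))  # evaluate all B models in one call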
@@ -2,7 +2,7 @@ import math
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from functorch import make_functional, grad_and_value, vmap
+from functorch import make_functional, grad_and_value, vmap, functional_init

 # Adapted from http://willwhitney.com/parallel-training-jax.html
 # GOAL: Demonstrate that it is possible to use eager-mode vmap
@@ -84,10 +84,7 @@ step4()
 # Step 5: We're ready for multiple models. Let's define an init_fn
 # that, given a number of models, returns to us all of the weights.
 def init_fn(num_models):
-    models = tuple(MLPClassifier() for _ in range(num_models))
-    weights = tuple(make_functional(model)[0] for model in models)
-    weights = tuple(zip(*weights))
-    weights = tuple(torch.stack(shards).detach() for shards in weights)
+    weights, _, _ = functional_init(MLPClassifier, (num_models,))()
     return weights

 # Step 6: Now, can we try multiple models at the same time?
@@ -4,6 +4,7 @@ from . import _C
 from ._src.vmap import vmap
 from ._src.eager_transforms import grad, grad_and_value, vjp, jacrev
 from ._src.make_functional import make_functional, make_functional_with_buffers, load_state
+from ._src.make_functional import functional_init, functional_init_with_buffers
 from ._src.python_key import wrap_key, PythonTensor, pythonkey_trace, hasPythonKey, removePythonKey, addPythonKey, make_fx, nnc_jit, make_nnc
 from ._src.nnc_compile import nnc_compile
@@ -167,3 +167,49 @@ def make_functional_with_buffers(model: nn.Module):
         return mutable_model(*data)

     return weights, buffers, fun, weight_descriptors, buf_descriptors


+def functional_init(model_class, ensemble_shape=(), device='cpu'):
+    def wrapped(*args, **kwargs):
+        if len(ensemble_shape) >= 2:
+            raise ValueError('NYI: ensemble_shape with more than 1 element')
+        if len(ensemble_shape) == 0:
+            model = model_class(*args, **kwargs).to(device)
+            return make_functional(model)
+        num_models = ensemble_shape[0]
+        if num_models <= 0:
+            raise ValueError(f"num_models {num_models} should be > 0")
+        # NB: Not very efficient, more of a POC
+        models = tuple(model_class(*args, **kwargs).to(device)
+                       for _ in range(num_models))
+        _, fn, names = make_functional(model_class(*args, **kwargs))
+        weights = tuple(make_functional(model)[0] for model in models)
+        weights = tuple(zip(*weights))
+        weights = tuple(torch.stack(shards).detach() for shards in weights)
+        return weights, fn, names
+    return wrapped


+def functional_init_with_buffers(model_class, ensemble_shape=(), device='cpu'):
+    def wrapped(*args, **kwargs):
+        if len(ensemble_shape) >= 2:
+            raise ValueError('NYI: ensemble_shape with more than 1 element')
+        if len(ensemble_shape) == 0:
+            model = model_class(*args, **kwargs).to(device)
+            return make_functional_with_buffers(model)
+        num_models = ensemble_shape[0]
+        if num_models <= 0:
+            raise ValueError(f"num_models {num_models} should be > 0")
+        # NB: Not very efficient, more of a POC
+        models = tuple(model_class(*args, **kwargs).to(device)
+                       for _ in range(num_models))
+        _, _, fn, weight_names, buffer_names = \
+            make_functional_with_buffers(model_class(*args, **kwargs))
+        weights, buffers = zip(*tuple(make_functional_with_buffers(model)[:2]
+                                      for model in models))
+        weights = tuple(zip(*weights))
+        weights = tuple(torch.stack(shards).detach() for shards in weights)
+        buffers = tuple(zip(*buffers))
+        buffers = tuple(torch.stack(shards).detach() for shards in buffers)
+        return weights, buffers, fn, weight_names, buffer_names
+    return wrapped
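To make the return convention of the new initializers concrete, here is a small sketch under the definitions above. `Tiny` is a hypothetical stand-in module used only for illustration; the last two return values come straight from `make_functional`.

    import torch
    import torch.nn as nn
    from functorch import functional_init

    class Tiny(nn.Module):  # hypothetical stand-in module
        def __init__(self):
            super().__init__()
            self.fc = nn.Linear(3, 4)

        def forward(self, x):
            return self.fc(x)

    weights, fn, names = functional_init(Tiny, (5,))()
    # weights holds one stacked, detached tensor per parameter, with the
    # ensemble dimension leading:
    #   weights[0].shape == (5, 4, 3)  # fc.weight across the 5 models
    #   weights[1].shape == (5, 4)     # fc.bias across the 5 models
    # fn and names describe a single model, as returned by make_functional.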
@@ -17,6 +17,7 @@ import functorch
 from functorch import (
     grad, vjp, vmap, jacrev, grad_and_value,
     make_functional, make_functional_with_buffers,
+    functional_init, functional_init_with_buffers,
 )

 # NB: numpy is a testing dependency!
@@ -374,6 +375,53 @@ class TestGradTransform(TestCase):
         with self.assertRaisesRegex(RuntimeError, 'Expected pytree structure'):
             result, = vjp_fn((v1, (v2, v3)))

+    def test_functional_init(self, device):
+        class MLPClassifier(nn.Module):
+            def __init__(self, hidden_dim=32, n_classes=2):
+                super().__init__()
+                self.hidden_dim = hidden_dim
+                self.n_classes = n_classes

+                self.fc1 = nn.Linear(2, self.hidden_dim)
+                self.fc2 = nn.Linear(self.hidden_dim, self.n_classes)

+            def forward(self, x):
+                x = self.fc1(x)
+                x = F.relu(x)
+                x = self.fc2(x)
+                x = F.log_softmax(x, -1)
+                return x

+        B = 10
+        weights, fn, _ = functional_init(MLPClassifier, (B,))(32, 2)
+        inputs = torch.randn(B, 7, 2)
+        vmap(fn)(weights, (inputs,))

+    def test_functional_init_with_buffers(self, device):
+        class MLPClassifier(nn.Module):
+            def __init__(self, hidden_dim=32, n_classes=2):
+                super().__init__()
+                self.hidden_dim = hidden_dim
+                self.n_classes = n_classes

+                self.fc1 = nn.Linear(2, self.hidden_dim)
+                self.bn = nn.BatchNorm1d(self.hidden_dim, affine=True)
+                self.fc2 = nn.Linear(self.hidden_dim, self.n_classes)

+            def forward(self, x):
+                x = self.fc1(x)
+                x = F.relu(x)
+                x = self.bn(x)
+                x = self.fc2(x)
+                x = F.log_softmax(x, -1)
+                return x

+        B = 10
+        weights, buffers, fn, _, _ = \
+            functional_init_with_buffers(MLPClassifier, [B])(32, 2)
+        inputs = torch.randn(B, 7, 2)
+        vmap(fn)(weights, buffers, (inputs,))


 class TestVmapOfGrad(TestCase):
     def test_per_sample_grads_inplace_view(self, device):