[functorch] Beef up make_functional docstring, update some examples

Richard Zou
2021-05-04 07:01:42 -07:00
committed by Jon Janzen
parent 6ecf169a07
commit 15ab42ce7c
7 changed files with 83 additions and 85 deletions

View File

@@ -61,7 +61,7 @@ pip install --user "git+https://github.com/zou3519/functorch.git"
```
Run a quick sanity check in python:
```
```py
>>> import torch
>>> from functorch import vmap
>>> x = torch.randn(3)
@@ -91,7 +91,10 @@ Right now, we support the following transforms:
- `vmap`
Furthermore, we have some utilities for working with PyTorch modules.
- `make_functional_with_buffers`
- `make_functional(model)` takes a model and returns its weights and a functional
version of the model that has no state (see the sketch below).
- `make_functional_with_buffers(model)` takes a model and returns its weights
and buffers and a functional version of the model that has no state.
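For example, here is a minimal sketch of `make_functional` (`nn.Linear` stands in for any model; the calling convention matches the docstrings below):
```py
import torch
import torch.nn as nn
from functorch import make_functional

model = nn.Linear(3, 3)
x = torch.randn(4, 3)

# the weights are extracted from the model; `func` takes them explicitly
weights, func, _ = make_functional(model)
out = func(weights, (x,))
```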
### vmap

View File

@@ -24,9 +24,9 @@ from opacus.utils.module_modification import convert_batchnorm_modules
from torchvision.datasets import CIFAR10
from tqdm import tqdm
from make_functional import make_functional, load_weights
from functools import partial
from functorch import vmap, grad_and_value
from functorch import make_functional, load_state
def save_checkpoint(state, is_best, filename="checkpoint.tar"):
torch.save(state, filename)
@@ -116,9 +116,9 @@ def train(args, model, train_loader, optimizer, epoch, device):
vmap(grads_loss_output, (None, 0, 0))(weights, images, target)
loss = sample_loss.mean()
# `load_weights` is the inverse operation of make_functional. We put
# `load_state` is the inverse operation of make_functional. We put
# things back into a model so that they're easier to manipulate
load_weights(model, descriptors, weights)
load_state(model, weights, descriptors)
for grad_sample, weight in zip(sample_grads, model.parameters()):
weight.grad_sample = grad_sample.detach()
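For reference, here is a condensed sketch of the per-sample-gradient recipe this hunk belongs to. `compute_loss` and the shapes are illustrative assumptions; `grads_loss_output` mirrors the name used above:
```py
import torch
import torch.nn as nn
from functorch import vmap, grad_and_value, make_functional, load_state

model = nn.Linear(3, 1)
weights, func, descriptors = make_functional(model)

def compute_loss(weights, image, target):
    # per-example loss; vmap supplies one example at a time
    output = func(weights, (image.unsqueeze(0),))
    return ((output.squeeze(0) - target) ** 2).mean()

grads_loss_output = grad_and_value(compute_loss)
images, targets = torch.randn(8, 3), torch.randn(8, 1)
sample_grads, sample_loss = vmap(grads_loss_output, (None, 0, 0))(weights, images, targets)

# put the weights back so the model is usable again
load_state(model, weights, descriptors)
```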

View File

@@ -1,71 +0,0 @@
import torch
import torch.nn as nn
from torch import Tensor
from typing import List, Tuple
import copy
# Utilities to make nn.Module "functional"
# In particular, the goal is to be able to provide a function that takes as input
# the parameters and evaluates the nn.Module using fixed inputs.
def _del_nested_attr(obj: nn.Module, names: List[str]) -> None:
"""
Deletes the attribute specified by the given list of names.
For example, to delete the attribute obj.conv.weight,
use _del_nested_attr(obj, ['conv', 'weight'])
"""
if len(names) == 1:
delattr(obj, names[0])
else:
_del_nested_attr(getattr(obj, names[0]), names[1:])
def _set_nested_attr(obj: nn.Module, names: List[str], value: Tensor) -> None:
"""
Set the attribute specified by the given list of names to value.
For example, to set the attribute obj.conv.weight,
use _set_nested_attr(obj, ['conv', 'weight'], value)
"""
if len(names) == 1:
setattr(obj, names[0], value)
else:
_set_nested_attr(getattr(obj, names[0]), names[1:], value)
def extract_weights(mod: nn.Module) -> Tuple[Tuple[Tensor, ...], List[str]]:
"""
This function removes all the Parameters from the model and
returns them as a tuple as well as their original attribute names.
The weights must be re-loaded with `load_weights` before the model
can be used again.
Note that this function modifies the model in place and after this
call, mod.parameters() will be empty.
"""
orig_params = tuple(mod.parameters())
# Remove all the parameters in the model
names = []
for name, p in list(mod.named_parameters()):
_del_nested_attr(mod, name.split("."))
names.append(name)
# Make params regular Tensors instead of nn.Parameter
params = tuple(p.detach().requires_grad_() for p in orig_params)
return params, names
def load_weights(mod: nn.Module, names: List[str], params: Tuple[Tensor, ...], as_params=False) -> None:
"""
Reload a set of weights so that `mod` can be used again to perform a forward pass.
Note that the `params` are regular Tensors (that can have history) and so are left
as Tensors. This means that mod.parameters() will still be empty after this call.
"""
for name, p in zip(names, params):
if as_params:
p = nn.Parameter(p)
_set_nested_attr(mod, name.split("."), p)
def make_functional(model: nn.Module):
weights, descriptors = extract_weights(model)
def fun(weights, data):
mutable_model = copy.deepcopy(model)
load_weights(mutable_model, descriptors, weights)
return mutable_model(*data)
return weights, fun, descriptors

View File

@@ -66,19 +66,16 @@ def train_step_fn(weights, batch, targets, lr=0.2):
with torch.no_grad():
for grad_weight, weight in zip(grad_weights, weights):
new_weights.append(weight - grad_weight * lr)
# NB: return looks weird because torch.vmap must return Tensors
return (loss, *new_weights)
return loss, new_weights
def unpack(train_result):
return train_result[0], train_result[1:]
# Step 4: Let's verify this actually trains.
# We should see the loss decrease.
def step4():
global weights
for i in range(2000):
loss, weights = unpack(train_step_fn(weights, points, labels))
loss, weights = train_step_fn(weights, points, labels)
if i % 100 == 0:
print(loss)
@@ -100,7 +97,7 @@ def step6():
parallel_train_step_fn = vmap(train_step_fn, in_dims=(0, None, None))
batched_weights = init_fn(num_models=2)
for i in range(2000):
loss, batched_weights = unpack(parallel_train_step_fn(batched_weights, points, labels))
loss, batched_weights = parallel_train_step_fn(batched_weights, points, labels)
if i % 200 == 0:
print(loss)
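The `unpack` shim could be dropped because `train_step_fn` now returns `(loss, new_weights)` directly through `vmap`. A minimal sketch of that behavior, assuming `vmap` propagates tuple outputs the way this change relies on:
```py
import torch
from functorch import vmap

def f(x):
    # return a tuple; vmap maps over each Tensor in the output
    return x.sum(), x * 2

sums, doubled = vmap(f)(torch.randn(5, 3))
assert sums.shape == (5,) and doubled.shape == (5, 3)
```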

View File

@@ -111,7 +111,7 @@ def main():
# Given this module we've created, rip out the parameters and buffers
# and return a functional version of the module. `fnet` is stateless
# and can be called with `fnet(params, buffers, args, kwargs)`
params, buffers, fnet, _, _, = make_functional_with_buffers(net)
params, buffers, fnet, _, _ = make_functional_with_buffers(net)
# We will use Adam to (meta-)optimize the initial parameters
# to be adapted.

View File

@@ -3,10 +3,9 @@ from . import _C
from ._src.vmap import vmap
from ._src.eager_transforms import grad, grad_and_value, vjp, jacrev
from ._src.make_functional import make_functional, make_functional_with_buffers
from ._src.make_functional import make_functional, make_functional_with_buffers, load_state
from ._src.python_key import wrap_key, WrapModule, PythonTensor, pythonkey_trace
# Monkeypatching lol
_old_cross_entropy = torch.nn.functional.cross_entropy

View File

@@ -76,7 +76,53 @@ def load_buffers(mod: nn.Module, names: List[str], params: Tuple[Tensor, ...], a
for name, p in zip(names, params):
_set_nested_attr(mod, name.split("."), p)
def load_state(
model: nn.Module,
weights: List[Tensor], weight_names: List[str],
buffers=(), buffer_names=()):
"""load_state(model, weights, weight_names, buffers=(), buffer_names=()) -> model
load_state takes `weights` and `buffers` and assigns them to the model.
This is the inverse operation of `make_functional`.
"""
assert len(weight_names) == len(weights)
load_weights(model, weight_names, weights)
if len(buffers) > 0:
assert len(buffer_names) == len(buffers)
load_buffers(model, buffer_names, buffers)
return model
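To illustrate the round trip with `make_functional`, a minimal sketch (the weight update is an arbitrary stand-in):
```py
import torch.nn as nn
from functorch import make_functional, load_state

model = nn.Linear(3, 3)
weights, func, weight_names = make_functional(model)

# manipulate the extracted weights out-of-place...
new_weights = [w * 0.9 for w in weights]

# ...then load them back; this inverts make_functional
model = load_state(model, new_weights, weight_names)
```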
def make_functional(model: nn.Module):
"""make_functional(model) -> weights, func, weight_names
Given an nn.Module, make_functional extracts the state (weights)
and returns a functional version of the model, `func`. This makes
it possible to use transforms over the parameters of
`model`.
`func` can be invoked as follows:
```
x = torch.randn(4, 3)
model = nn.Linear(3, 3)
weights, func, _ = make_functional(model)
func(weights, (x,))
```
And here is an example of applying the grad transform:
```
x = torch.randn(4, 3)
model = nn.Linear(3, 3)
weights, func, _ = make_functional(model)
# grad expects a scalar output, so reduce the model output with a sum
grad_weights = grad(lambda w, data: func(w, data).sum())(weights, (x,))
```
To put the state back into a model, use `load_state`.
"""
buffers = list(model.buffers())
if len(buffers) > 0:
raise RuntimeError('make_functional(model): `model` has buffers. Please use '
'make_functional_with_buffers(model) instead.')
weights, descriptors = extract_weights(model)
def fun(weights, data):
@@ -87,6 +133,30 @@ def make_functional(model: nn.Module):
return weights, fun, descriptors
def make_functional_with_buffers(model: nn.Module):
"""make_functional_with_buffers(model) -> weights, buffers, func, weight_names, buffer_names
Given an nn.Module, make_functional_with_buffers extracts the state (weights and buffers)
and returns a functional version of the model, `func`.
`func` can be invoked as follows:
```
x = torch.randn(4, 3)
model = nn.Linear(3, 3)
weights, buffers, func, _, _ = make_functional_with_buffers(model)
func(weights, buffers, (x,))
```
And here is an example of applying the grad transform:
```
x = torch.randn(4, 3)
model = nn.Linear(3, 3)
weights, buffers, func, _, _ = make_functional_with_buffers(model)
# grad expects a scalar output, so reduce the model output with a sum
grad_weights = grad(lambda w, b, data: func(w, b, data).sum())(weights, buffers, (x,))
```
To put the state back into a model, use `load_state`.
"""
weights, weight_descriptors = extract_weights(model)
buffers, buf_descriptors = extract_buffers(model)