[functorch] Beef up make_functional docstring, update some examples
@@ -61,7 +61,7 @@ pip install --user "git+https://github.com/zou3519/functorch.git"
 ```
 
 Run a quick sanity check in python:
-```
+```py
 >>> import torch
 >>> from functorch import vmap
 >>> x = torch.randn(3)
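For context, the sanity check continues past the lines shown in this hunk; a plausible completion, assuming only standard `vmap` semantics (mapping a function over the leading dimension of its inputs), looks like:

```py
>>> import torch
>>> from functorch import vmap
>>> x = torch.randn(3)
>>> # vmap(torch.sin) applies sin to each element of x along dim 0.
>>> y = vmap(torch.sin)(x)
>>> torch.allclose(y, x.sin())
True
```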
@@ -91,7 +91,10 @@ Right now, we support the following transforms:
 - `vmap`
 
 Furthermore, we have some utilities for working with PyTorch modules.
-- `make_functional_with_buffers`
+- `make_functional(model)` takes a model and returns its weights and a function
+version of the model that has no state.
+- `make_functional_with_buffers(model)` takes a model and returns its weights
+and buffers and a function version of the model that has no state.
 
 ### vmap
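As a quick illustration of the two utilities described in the new README text, here is a hedged sketch; the `nn.Linear` and `nn.BatchNorm1d` models are stand-ins, and the return conventions follow the docstrings added later in this commit:

```py
import torch
import torch.nn as nn
from functorch import make_functional, make_functional_with_buffers

x = torch.randn(4, 3)

# Stateless version of a buffer-free module.
weights, func, weight_names = make_functional(nn.Linear(3, 3))
out = func(weights, (x,))

# Modules with buffers (e.g. BatchNorm running stats) need the
# _with_buffers variant, which also returns the buffers.
model = nn.Sequential(nn.Linear(3, 3), nn.BatchNorm1d(3))
weights, buffers, func, _, _ = make_functional_with_buffers(model)
out = func(weights, buffers, (x,))
```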
@@ -24,9 +24,9 @@ from opacus.utils.module_modification import convert_batchnorm_modules
 from torchvision.datasets import CIFAR10
 from tqdm import tqdm
 
-from make_functional import make_functional, load_weights
 from functools import partial
 from functorch import vmap, grad_and_value
+from functorch import make_functional, load_state
 
 def save_checkpoint(state, is_best, filename="checkpoint.tar"):
     torch.save(state, filename)
@@ -116,9 +116,9 @@ def train(args, model, train_loader, optimizer, epoch, device):
         vmap(grads_loss_output, (None, 0, 0))(weights, images, target)
     loss = sample_loss.mean()
 
-    # `load_weights` is the inverse operation of make_functional. We put
+    # `load_state` is the inverse operation of make_functional. We put
     # things back into a model so that they're easier to manipulate
-    load_weights(model, descriptors, weights)
+    load_state(model, weights, descriptors)
     for grad_sample, weight in zip(sample_grads, model.parameters()):
        weight.grad_sample = grad_sample.detach()
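The pattern in this hunk is per-sample gradients: `vmap` over a gradient function gives one gradient per example instead of one averaged gradient per batch. A self-contained hedged sketch follows; the tiny model and loss are illustrative, and `grads_loss_output` from the diff is assumed to be built with `grad_and_value` as shown:

```py
import torch
import torch.nn as nn
from functorch import vmap, grad_and_value, make_functional

model = nn.Linear(3, 2)
weights, func, descriptors = make_functional(model)

def compute_loss(weights, image, target):
    # Add a batch dimension of 1 so the module sees a batch.
    output = func(weights, (image.unsqueeze(0),))
    return nn.functional.cross_entropy(output, target.unsqueeze(0))

# grad_and_value(f) returns a function computing (grads wrt arg 0, value of f).
# vmap-ing it with in_dims=(None, 0, 0) shares the weights across the batch
# and maps over the per-sample images and targets.
grads_loss_fn = grad_and_value(compute_loss)
images, targets = torch.randn(8, 3), torch.randint(0, 2, (8,))
sample_grads, sample_loss = vmap(grads_loss_fn, (None, 0, 0))(weights, images, targets)
```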
@@ -1,71 +0,0 @@
-import torch
-import torch.nn as nn
-from torch import Tensor
-from typing import List, Tuple
-import copy
-
-# Utilities to make nn.Module "functional"
-# In particular the goal is to be able to provide a function that takes as input
-# the parameters and evaluates the nn.Module using fixed inputs.
-def _del_nested_attr(obj: nn.Module, names: List[str]) -> None:
-    """
-    Deletes the attribute specified by the given list of names.
-    For example, to delete the attribute obj.conv.weight,
-    use _del_nested_attr(obj, ['conv', 'weight'])
-    """
-    if len(names) == 1:
-        delattr(obj, names[0])
-    else:
-        _del_nested_attr(getattr(obj, names[0]), names[1:])
-
-def _set_nested_attr(obj: nn.Module, names: List[str], value: Tensor) -> None:
-    """
-    Set the attribute specified by the given list of names to value.
-    For example, to set the attribute obj.conv.weight,
-    use _set_nested_attr(obj, ['conv', 'weight'], value)
-    """
-    if len(names) == 1:
-        setattr(obj, names[0], value)
-    else:
-        _set_nested_attr(getattr(obj, names[0]), names[1:], value)
-
-def extract_weights(mod: nn.Module) -> Tuple[Tuple[Tensor, ...], List[str]]:
-    """
-    This function removes all the Parameters from the model and
-    returns them as a tuple, as well as their original attribute names.
-    The weights must be re-loaded with `load_weights` before the model
-    can be used again.
-    Note that this function modifies the model in place and after this
-    call, mod.parameters() will be empty.
-    """
-    orig_params = tuple(mod.parameters())
-    # Remove all the parameters in the model
-    names = []
-    for name, p in list(mod.named_parameters()):
-        _del_nested_attr(mod, name.split("."))
-        names.append(name)
-
-    # Make params regular Tensors instead of nn.Parameter
-    params = tuple(p for p in orig_params)
-    return params, names
-
-def load_weights(mod: nn.Module, names: List[str], params: Tuple[Tensor, ...], as_params=False) -> None:
-    """
-    Reload a set of weights so that `mod` can be used again to perform a forward pass.
-    Note that the `params` are regular Tensors (that can have history) and so are left
-    as Tensors. This means that mod.parameters() will still be empty after this call.
-    """
-    for name, p in zip(names, params):
-        if as_params:
-            p = nn.Parameter(p)
-        _set_nested_attr(mod, name.split("."), p)
-
-def make_functional(model: nn.Module):
-    weights, descriptors = extract_weights(model)
-
-    def fun(weights, data):
-        mutable_model = copy.deepcopy(model)
-        load_weights(mutable_model, descriptors, weights)
-        return mutable_model(*data)
-
-    return weights, fun, descriptors
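For readers of the helper file deleted above, a hedged usage sketch of how it worked before this commit (the two-layer model is illustrative; the import assumes the deleted module was on the path as `make_functional`):

```py
import torch
import torch.nn as nn

# Pre-commit usage of the now-deleted standalone helper.
from make_functional import make_functional, load_weights

model = nn.Sequential(nn.Linear(3, 4), nn.Linear(4, 2))
weights, func, descriptors = make_functional(model)

# The functional version threads the weights in explicitly;
# internally it deepcopies `model` and loads the weights into the copy.
x = torch.randn(5, 3)
out = func(weights, (x,))

# `model` itself was stripped in place: parameters() is now empty
# until the weights are loaded back.
assert len(list(model.parameters())) == 0
load_weights(model, descriptors, weights)
```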
@@ -66,19 +66,16 @@ def train_step_fn(weights, batch, targets, lr=0.2):
     with torch.no_grad():
         for grad_weight, weight in zip(grad_weights, weights):
             new_weights.append(weight - grad_weight * lr)
-    # NB: return looks weird because torch.vmap must return Tensors
-    return (loss, *new_weights)
-
+    return loss, new_weights
 
-def unpack(train_result):
-    return train_result[0], train_result[1:]
-
 # Step 4: Let's verify this actually trains.
 # We should see the loss decrease.
 def step4():
     global weights
     for i in range(2000):
-        loss, weights = unpack(train_step_fn(weights, points, labels))
+        loss, weights = train_step_fn(weights, points, labels)
         if i % 100 == 0:
             print(loss)
@@ -100,7 +97,7 @@ def step6():
     parallel_train_step_fn = vmap(train_step_fn, in_dims=(0, None, None))
     batched_weights = init_fn(num_models=2)
     for i in range(2000):
-        loss, batched_weights = unpack(parallel_train_step_fn(batched_weights, points, labels))
+        loss, batched_weights = parallel_train_step_fn(batched_weights, points, labels)
         if i % 200 == 0:
             print(loss)
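The ensembling trick used in `step6`, vmapping a whole train step over a stack of per-model weights, in a hedged end-to-end sketch. A single linear weight vector and random data stand in for the example's real `init_fn`, `points`, and `labels`, and it relies on the tuple returns that this commit's simplification (dropping `unpack`) implies `vmap` now supports:

```py
import torch
from functorch import vmap, grad_and_value

# Two models' worth of weights, stacked along a new leading dim.
def init_fn(num_models):
    return torch.randn(num_models, 3)

points = torch.randn(100, 3)
labels = torch.randn(100)

def compute_loss(weights, batch, targets):
    preds = batch @ weights
    return ((preds - targets) ** 2).mean()

def train_step_fn(weights, batch, targets, lr=0.2):
    grad_weights, loss = grad_and_value(compute_loss)(weights, batch, targets)
    return loss, weights - lr * grad_weights

# in_dims=(0, None, None): map over the stacked weights only;
# every model in the ensemble sees the same batch.
parallel_train_step_fn = vmap(train_step_fn, in_dims=(0, None, None))
batched_weights = init_fn(num_models=2)
loss, batched_weights = parallel_train_step_fn(batched_weights, points, labels)
```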
@@ -111,7 +111,7 @@ def main():
     # Given this module we've created, rip out the parameters and buffers
     # and return a functional version of the module. `fnet` is stateless
     # and can be called with `fnet(params, buffers, args, kwargs)`
-    params, buffers, fnet, _, _, = make_functional_with_buffers(net)
+    params, buffers, fnet, _, _ = make_functional_with_buffers(net)
 
     # We will use Adam to (meta-)optimize the initial parameters
     # to be adapted.
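A hedged sketch of the setup these comments describe: `net` is assumed to be a small MLP, and the Adam wiring follows the comment rather than code shown in this diff:

```py
import torch
import torch.nn as nn
from functorch import make_functional_with_buffers

net = nn.Sequential(nn.Linear(1, 32), nn.ReLU(), nn.Linear(32, 1))
params, buffers, fnet, _, _ = make_functional_with_buffers(net)

# The extracted params are plain tensors; make them leaf tensors that Adam
# can update, so the *initialization* itself is what gets meta-optimized.
params = [p.detach().requires_grad_() for p in params]
meta_opt = torch.optim.Adam(params, lr=1e-3)

x = torch.randn(4, 1)
out = fnet(params, buffers, (x,))
```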
@@ -3,10 +3,9 @@ from . import _C
 
 from ._src.vmap import vmap
 from ._src.eager_transforms import grad, grad_and_value, vjp, jacrev
-from ._src.make_functional import make_functional, make_functional_with_buffers
+from ._src.make_functional import make_functional, make_functional_with_buffers, load_state
 from ._src.python_key import wrap_key, WrapModule, PythonTensor, pythonkey_trace
 
 
 # Monkeypatching lol
 _old_cross_entropy = torch.nn.functional.cross_entropy
@@ -76,7 +76,53 @@ def load_buffers(mod: nn.Module, names: List[str], params: Tuple[Tensor, ...], a
     for name, p in zip(names, params):
         _set_nested_attr(mod, name.split("."), p)
 
+def load_state(
+        model: nn.Module,
+        weights: List[Tensor], weight_names: List[str],
+        buffers=(), buffer_names=()):
+    """load_state(model, weights, weight_names, buffers=(), buffer_names=()) -> model
+
+    load_state takes `weights` and `buffers` and assigns them to the model.
+    This is the inverse operation of `make_functional`.
+    """
+    assert len(weight_names) == len(weights)
+    load_weights(model, weight_names, weights)
+    if len(buffers) > 0:
+        assert len(buffer_names) == len(buffers)
+        load_buffers(model, buffer_names, buffers)
+    return model
+
+
 def make_functional(model: nn.Module):
+    """make_functional(model) -> weights, func, weight_names
+
+    Given an nn.Module, make_functional extracts the state (weights)
+    and returns a functional version of the model, `func`. This makes
+    it possible to use transforms over the parameters of `model`.
+
+    `func` can be invoked as follows:
+    ```
+    x = torch.randn(4, 3)
+    model = nn.Linear(3, 3)
+    weights, func, _ = make_functional(model)
+    func(weights, (x,))
+    ```
+
+    And here is an example of applying the grad transform:
+    ```
+    x = torch.randn(4, 3)
+    model = nn.Linear(3, 3)
+    weights, func, _ = make_functional(model)
+    grad_weights = grad(func)(weights, (x,))
+    ```
+
+    To put the state back into a model, use `load_state`.
+    """
+    buffers = list(model.buffers())
+    if len(buffers) > 0:
+        raise RuntimeError('make_functional(model): `model` has buffers. Please use '
+                           'make_functional_with_buffers(model) instead.')
     weights, descriptors = extract_weights(model)
 
     def fun(weights, data):
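A hedged round-trip sketch of the `load_state` / `make_functional` pairing this hunk adds: one functional SGD step on illustrative data, then the updated state pushed back into the module. The call shapes follow the docstrings above:

```py
import torch
import torch.nn as nn
from functorch import grad, make_functional, load_state

model = nn.Linear(3, 1)
weights, func, weight_names = make_functional(model)

def loss_fn(weights, x, y):
    return ((func(weights, (x,)) - y) ** 2).mean()

x, y = torch.randn(8, 3), torch.randn(8, 1)

# grad differentiates wrt arg 0, the tuple of weights.
grads = grad(loss_fn)(weights, x, y)
new_weights = [w - 0.1 * g for w, g in zip(weights, grads)]

# Inverse of make_functional: assign the updated tensors back to the module.
model = load_state(model, new_weights, weight_names)
```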
@@ -87,6 +133,30 @@ def make_functional(model: nn.Module):
     return weights, fun, descriptors
 
 def make_functional_with_buffers(model: nn.Module):
+    """make_functional_with_buffers(model) -> weights, buffers, func, weight_names, buffer_names
+
+    Given an nn.Module, make_functional_with_buffers extracts the state (weights and buffers)
+    and returns a functional version of the model, `func`.
+
+    `func` can be invoked as follows:
+    ```
+    x = torch.randn(4, 3)
+    model = nn.Linear(3, 3)
+    weights, buffers, func, _, _ = make_functional_with_buffers(model)
+    func(weights, buffers, (x,))
+    ```
+
+    And here is an example of applying the grad transform:
+    ```
+    x = torch.randn(4, 3)
+    model = nn.Linear(3, 3)
+    weights, buffers, func, _, _ = make_functional_with_buffers(model)
+    grad_weights = grad(func)(weights, buffers, (x,))
+    ```
+
+    To put the state back into a model, use `load_state`.
+    """
     weights, weight_descriptors = extract_weights(model)
     buffers, buf_descriptors = extract_buffers(model)
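Finally, a hedged sketch of the buffer guard and the `_with_buffers` path working together; the BatchNorm model is illustrative:

```py
import torch
import torch.nn as nn
from functorch import make_functional, make_functional_with_buffers

model = nn.Sequential(nn.Linear(3, 3), nn.BatchNorm1d(3))

# BatchNorm tracks running statistics as buffers, so the plain variant
# refuses and points at make_functional_with_buffers instead.
try:
    make_functional(model)
except RuntimeError as e:
    print(e)

weights, buffers, func, _, _ = make_functional_with_buffers(model)
x = torch.randn(4, 3)
out = func(weights, buffers, (x,))
```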