mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
72 lines
2.6 KiB
Python
72 lines
2.6 KiB
Python
import torch
|
|
import torch.nn as nn
|
|
from torch import Tensor
|
|
from typing import List, Tuple
|
|
import copy
|
|
|
|
# Utilities to make nn.Module "functional"
|
|
# In particular the goal is to be able to provide a function that takes as input
|
|
# the parameters and evaluates the nn.Module using fixed inputs.
|
|
def _del_nested_attr(obj: nn.Module, names: List[str]) -> None:
|
|
"""
|
|
Deletes the attribute specified by the given list of names.
|
|
For example, to delete the attribute obj.conv.weight,
|
|
use _del_nested_attr(obj, ['conv', 'weight'])
|
|
"""
|
|
if len(names) == 1:
|
|
delattr(obj, names[0])
|
|
else:
|
|
_del_nested_attr(getattr(obj, names[0]), names[1:])
|
|
|
|
def _set_nested_attr(obj: nn.Module, names: List[str], value: Tensor) -> None:
|
|
"""
|
|
Set the attribute specified by the given list of names to value.
|
|
For example, to set the attribute obj.conv.weight,
|
|
use _del_nested_attr(obj, ['conv', 'weight'], value)
|
|
"""
|
|
if len(names) == 1:
|
|
setattr(obj, names[0], value)
|
|
else:
|
|
_set_nested_attr(getattr(obj, names[0]), names[1:], value)
|
|
|
|
def extract_weights(mod: nn.Module) -> Tuple[Tuple[Tensor, ...], List[str]]:
    """
    Strip the parameters off ``mod`` and hand them back to the caller.

    Returns a pair ``(params, names)``:

    * ``params`` is a tuple of the objects yielded by ``mod.parameters()``
      before anything is removed. They are returned as-is; no copy, detach
      or type conversion is applied here.
    * ``names`` lists the dotted attribute names reported by
      ``mod.named_parameters()``, in iteration order.

    The module is modified in place: each attribute in ``names`` is deleted
    from it, so the weights must be put back with `load_weights` before the
    module can be used again.

    NOTE(review): the two results are paired positionally by `load_weights`,
    which presumes ``parameters()`` and ``named_parameters()`` iterate in the
    same order and that every parameter is reachable through one reported
    name -- confirm for modules with tied/shared weights.
    """
    # Capture the parameter objects first: once their attributes are deleted
    # the module no longer yields them.
    params = tuple(mod.parameters())

    # Snapshot all names before deleting so the removals cannot disturb iteration.
    names = [qualified_name for qualified_name, _ in mod.named_parameters()]
    for qualified_name in names:
        _del_nested_attr(mod, qualified_name.split("."))

    return params, names
|
|
|
|
def load_weights(mod: nn.Module, names: List[str], params: Tuple[Tensor, ...], as_params=False) -> None:
    """
    Attach ``params`` back onto ``mod`` so it can run a forward pass again.

    ``names`` holds dotted attribute paths (e.g. ``"conv.weight"``) and
    ``params`` the values to assign; the two are paired positionally with
    ``zip``, so surplus entries in the longer sequence are ignored.

    With ``as_params=False`` (the default) each value is assigned unchanged,
    so a tensor that carries autograd history keeps it and is stored as a
    plain attribute. With ``as_params=True`` each value is wrapped in
    ``nn.Parameter`` before being assigned.
    """
    for qualified_name, tensor in zip(names, params):
        value = nn.Parameter(tensor) if as_params else tensor
        _set_nested_attr(mod, qualified_name.split("."), value)
|
|
|
|
def make_functional(model: nn.Module):
    """
    Split ``model`` into its weights and a function that takes weights explicitly.

    The parameters are pulled out of ``model`` with `extract_weights` (which
    modifies ``model`` in place). Returns a triple ``(weights, fun, names)``:

    * ``weights`` -- the tuple of extracted parameters;
    * ``fun(weights, data)`` -- deep-copies the stripped ``model``, loads the
      given ``weights`` into the copy with `load_weights`, and returns the
      result of calling the copy as ``copy(*data)``;
    * ``names`` -- the attribute names that match ``weights`` positionally.
    """
    weights, param_names = extract_weights(model)

    def fun(weights, data):
        # Load into a fresh copy on every call so the stripped ``model``
        # captured by this closure is never modified.
        replica = copy.deepcopy(model)
        load_weights(replica, param_names, weights)
        return replica(*data)

    return weights, fun, param_names
|