[functorch] Add a short-term parameter tying error message (pytorch/functorch#552)

Richard Zou
2022-03-02 10:47:26 -05:00
committed by Jon Janzen
parent 29d8a62da5
commit 612df22e7d
4 changed files with 79 additions and 34 deletions

View File

@@ -9,6 +9,7 @@ from torch.nn.utils import _stateless
 from functorch._C import CompileCache
 from .decompositions import register_decomposition
 from .partitioners import default_partition
+from .named_members_polyfill import _named_parameters, _named_buffers
 from typing import Callable, List, Dict, Any, Tuple, Optional
 
 pytree._register_pytree_node(
@@ -474,40 +475,6 @@ def clear_compile_cache():
     compile_cache = None
 
 
-# Polyfilled from pytorch core while we figure out the `remove_duplicate` issues.
-def _named_members(mod, get_members_fn, prefix='', recurse=True, remove_duplicate=True):
-    r"""Helper method for yielding various names + members of modules."""
-    memo = set()
-    modules = mod.named_modules(prefix=prefix, remove_duplicate=remove_duplicate) if recurse else [(prefix, mod)]
-    for module_prefix, module in modules:
-        members = get_members_fn(module)
-        for k, v in members:
-            if v is None or v in memo:
-                continue
-            if remove_duplicate:
-                memo.add(v)
-            name = module_prefix + ('.' if module_prefix else '') + k
-            yield name, v
-
-
-def _named_parameters(mod, prefix: str = '', recurse: bool = True, remove_duplicate: bool = True):
-    gen = _named_members(
-        mod,
-        lambda module: module._parameters.items(),
-        prefix=prefix, recurse=recurse, remove_duplicate=remove_duplicate)
-    for elem in gen:
-        yield elem
-
-
-def _named_buffers(mod, prefix: str = '', recurse: bool = True, remove_duplicate: bool = True):
-    gen = _named_members(
-        mod,
-        lambda module: module._buffers.items(),
-        prefix=prefix, recurse=recurse, remove_duplicate=remove_duplicate)
-    for elem in gen:
-        yield elem
-
-
 def aot_module(mod, *args, **kwargs):
     def functional_call(named_params, named_buffers, *args, **kwargs):
         params_and_buffers = {**named_params, **named_buffers}
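
A note on the pattern above: `functional_call` re-runs `mod` with caller-supplied tensors in place of its registered state. A minimal sketch of that pattern, assuming the torch 1.11-era private `torch.nn.utils._stateless` API imported at the top of this file (the module and shapes below are invented for illustration):

import torch
import torch.nn as nn
from torch.nn.utils import _stateless

mod = nn.Linear(3, 3)
# Replacement tensors, keyed by the same dotted names the module uses.
named_params = {k: v.detach().clone() for k, v in mod.named_parameters()}
named_buffers = {k: v.detach().clone() for k, v in mod.named_buffers()}
params_and_buffers = {**named_params, **named_buffers}

# Runs mod's forward with the replacement tensors instead of its own state.
out = _stateless.functional_call(mod, params_and_buffers, (torch.randn(2, 3),))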

View File

@@ -8,6 +8,7 @@ import torch
 import torch.nn as nn
 from torch import Tensor
 from typing import List, Tuple
+from .named_members_polyfill import _named_parameters, _named_buffers
 import copy
 
 # Utilities to make nn.Module "functional"
@@ -46,6 +47,15 @@ def _get_nested_attr(obj: nn.Module, names: List[str]) -> None:
         _get_nested_attr(getattr(obj, names[0]), names[1:])
 
 
+def raise_parameter_tying_error():
+    raise RuntimeError(
+        "make_functional(module): we don't yet support models that "
+        "do parameter tying (also sometimes known as weight sharing). "
+        "Please try to rewrite your model by replacing all instances of the "
+        "tied parameter with another and/or comment your support in "
+        "https://github.com/pytorch/functorch/issues/446")
+
+
 def extract_weights(mod: nn.Module) -> Tuple[Tuple[Tensor, ...], List[str]]:
     """
     This function removes all the Parameters from the model and
@@ -55,7 +65,11 @@ def extract_weights(mod: nn.Module) -> Tuple[Tuple[Tensor, ...], List[str]]:
     Note that this function modifies the model in place and after this
     call, mod.parameters() will be empty.
     """
+    num_orig_params_with_duplicates = len(tuple(_named_parameters(mod, remove_duplicate=False)))
     orig_params = tuple(mod.parameters())
+    if len(orig_params) != num_orig_params_with_duplicates:
+        raise_parameter_tying_error()
+
     # Remove all the parameters in the model
     names = []
     for name, p in list(mod.named_parameters()):
@@ -91,7 +105,11 @@ def _swap_state(mod: nn.Module, split_names: List[str], elems):
 
 
 def extract_buffers(mod: nn.Module) -> Tuple[Tuple[Tensor, ...], List[str]]:
+    num_orig_params_with_duplicates = len(tuple(_named_buffers(mod, remove_duplicate=False)))
     orig_params = tuple(mod.buffers())
+    if len(orig_params) != num_orig_params_with_duplicates:
+        raise_parameter_tying_error()
+
     # Remove all the parameters in the model
     names = []
     for name, p in list(mod.named_buffers()):
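
To see what these checks catch: `mod.parameters()` deduplicates, so a tied parameter is counted once, while the polyfill with `remove_duplicate=False` counts every registration. A small repro (the `TiedBias` module is invented, and the import path assumes functorch's `_src` layout):

import torch
import torch.nn as nn
from functorch._src.named_members_polyfill import _named_parameters

class TiedBias(nn.Module):
    def __init__(self):
        super().__init__()
        self.bias = nn.Parameter(torch.randn(3))
        self.linear = nn.Linear(3, 3)
        self.linear.bias = self.bias  # bias and linear.bias now share one Parameter

mod = TiedBias()
print(len(tuple(mod.parameters())))                                # 2: the tied bias is deduplicated
print(len(tuple(_named_parameters(mod, remove_duplicate=False))))  # 3: both registrations are counted
# The mismatch (2 != 3) is exactly what trips raise_parameter_tying_error().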

View File

@@ -0,0 +1,32 @@
+# Polyfilled from pytorch core while we figure out the `remove_duplicate` issues.
+def _named_members(mod, get_members_fn, prefix='', recurse=True, remove_duplicate=True):
+    r"""Helper method for yielding various names + members of modules."""
+    memo = set()
+    modules = mod.named_modules(prefix=prefix, remove_duplicate=remove_duplicate) if recurse else [(prefix, mod)]
+    for module_prefix, module in modules:
+        members = get_members_fn(module)
+        for k, v in members:
+            if v is None or v in memo:
+                continue
+            if remove_duplicate:
+                memo.add(v)
+            name = module_prefix + ('.' if module_prefix else '') + k
+            yield name, v
+
+
+def _named_parameters(mod, prefix: str = '', recurse: bool = True, remove_duplicate: bool = True):
+    gen = _named_members(
+        mod,
+        lambda module: module._parameters.items(),
+        prefix=prefix, recurse=recurse, remove_duplicate=remove_duplicate)
+    for elem in gen:
+        yield elem
+
+
+def _named_buffers(mod, prefix: str = '', recurse: bool = True, remove_duplicate: bool = True):
+    gen = _named_members(
+        mod,
+        lambda module: module._buffers.items(),
+        prefix=prefix, recurse=recurse, remove_duplicate=remove_duplicate)
+    for elem in gen:
+        yield elem
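
As a usage sketch of the polyfill above: for a module registering one tensor under two names, `remove_duplicate=True` (the default) yields only the first name seen, thanks to the `memo` set, while `remove_duplicate=False` yields every registration. The `Shared` module here is invented for illustration, using the `_named_parameters` defined above:

import torch
import torch.nn as nn

class Shared(nn.Module):
    def __init__(self):
        super().__init__()
        self.w = nn.Parameter(torch.randn(3))
        self.w_alias = self.w  # registers the same Parameter under a second name

m = Shared()
print([name for name, _ in _named_parameters(m)])
# ['w'] -- the duplicate registration is skipped
print([name for name, _ in _named_parameters(m, remove_duplicate=False)])
# ['w', 'w_alias'] -- every registration is reported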

View File

@@ -1988,6 +1988,34 @@ class TestComposability(TestCase):
         # Honestly IDK what the result here is... but at least it runs
 
 
+class TestMakeFunctional(TestCase):
+    def test_parameter_tying(self):
+        class Foo(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.bias = nn.Parameter(torch.randn(3))
+                self.linear = nn.Linear(3, 3)
+                self.linear.bias = self.bias
+                self.linear_tied = self.linear
+
+        mod = Foo()
+        with self.assertRaisesRegex(RuntimeError, "parameter tying"):
+            func, params = make_functional(mod)
+
+    def test_buffer_tying(self):
+        class Foo(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.bias = nn.Parameter(torch.randn(3))
+                self.linear = nn.Linear(3, 3)
+                self.register_buffer('buffer', torch.randn(3))
+                self.register_buffer('buffer_tied', self.buffer)
+
+        mod = Foo()
+        with self.assertRaisesRegex(RuntimeError, "parameter tying"):
+            func, params, buffers = make_functional_with_buffers(mod)
+
+
 class TestExamplesCorrectness(TestCase):
     def test_maml_regression(self, device):
         class ThreeLayerNet(nn.Module):
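
Finally, a sketch of the rewrite the new error message asks users for (hypothetical, not part of this commit): register the shared Parameter once and apply it in forward, so that no tensor appears under two names.

import torch
import torch.nn as nn

class FooUntied(nn.Module):
    def __init__(self):
        super().__init__()
        # The bias is registered exactly once ...
        self.bias = nn.Parameter(torch.randn(3))
        self.linear = nn.Linear(3, 3, bias=False)

    def forward(self, x):
        # ... and reused wherever the tied copy was previously needed.
        return self.linear(x) + self.bias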