mirror of
https://github.com/pytorch/pytorch.git
synced 2025-11-02 06:24:59 +08:00
[functorch] Add a short-term parameter tying error message (pytorch/functorch#552)
This commit is contained in:
@ -9,6 +9,7 @@ from torch.nn.utils import _stateless
|
||||
from functorch._C import CompileCache
|
||||
from .decompositions import register_decomposition
|
||||
from .partitioners import default_partition
|
||||
from .named_members_polyfill import _named_parameters, _named_buffers
|
||||
from typing import Callable, List, Dict, Any, Tuple, Optional
|
||||
|
||||
pytree._register_pytree_node(
|
||||
@ -474,40 +475,6 @@ def clear_compile_cache():
|
||||
compile_cache = None
|
||||
|
||||
|
||||
# Polyfilled from pytorch core while we figure out the `remove_duplicate` issues.
|
||||
def _named_members(mod, get_members_fn, prefix='', recurse=True, remove_duplicate=True):
|
||||
r"""Helper method for yielding various names + members of modules."""
|
||||
memo = set()
|
||||
modules = mod.named_modules(prefix=prefix, remove_duplicate=remove_duplicate) if recurse else [(prefix, mod)]
|
||||
for module_prefix, module in modules:
|
||||
members = get_members_fn(module)
|
||||
for k, v in members:
|
||||
if v is None or v in memo:
|
||||
continue
|
||||
if remove_duplicate:
|
||||
memo.add(v)
|
||||
name = module_prefix + ('.' if module_prefix else '') + k
|
||||
yield name, v
|
||||
|
||||
|
||||
def _named_parameters(mod, prefix: str = '', recurse: bool = True, remove_duplicate: bool = True):
    """Yield (name, parameter) pairs of ``mod``, polyfilling ``remove_duplicate``."""
    yield from _named_members(
        mod,
        lambda module: module._parameters.items(),
        prefix=prefix, recurse=recurse, remove_duplicate=remove_duplicate)
|
||||
|
||||
|
||||
def _named_buffers(mod, prefix: str = '', recurse: bool = True, remove_duplicate: bool = True):
    """Yield (name, buffer) pairs of ``mod``, polyfilling ``remove_duplicate``."""
    yield from _named_members(
        mod,
        lambda module: module._buffers.items(),
        prefix=prefix, recurse=recurse, remove_duplicate=remove_duplicate)
|
||||
|
||||
|
||||
def aot_module(mod, *args, **kwargs):
|
||||
def functional_call(named_params, named_buffers, *args, **kwargs):
|
||||
params_and_buffers = {**named_params, **named_buffers}
|
||||
|
||||
@ -8,6 +8,7 @@ import torch
|
||||
import torch.nn as nn
|
||||
from torch import Tensor
|
||||
from typing import List, Tuple
|
||||
from .named_members_polyfill import _named_parameters, _named_buffers
|
||||
import copy
|
||||
|
||||
# Utilities to make nn.Module "functional"
|
||||
@ -46,6 +47,15 @@ def _get_nested_attr(obj: nn.Module, names: List[str]) -> None:
|
||||
_get_nested_attr(getattr(obj, names[0]), names[1:])
|
||||
|
||||
|
||||
def raise_parameter_tying_error():
    """Raise a RuntimeError explaining that tied (shared) parameters are unsupported."""
    message = (
        "make_functional(module): we don't yet support models that "
        "do parameter tying (also sometimes known as weight sharing). "
        "Please try to rewrite your model by replacing all instances of the "
        "tied parameter with another and/or comment your support in "
        "https://github.com/pytorch/functorch/issues/446")
    raise RuntimeError(message)
|
||||
|
||||
|
||||
def extract_weights(mod: nn.Module) -> Tuple[Tuple[Tensor, ...], List[str]]:
|
||||
"""
|
||||
This function removes all the Parameters from the model and
|
||||
@ -55,7 +65,11 @@ def extract_weights(mod: nn.Module) -> Tuple[Tuple[Tensor, ...], List[str]]:
|
||||
Note that this function modifies the model in place and after this
|
||||
call, mod.parameters() will be empty.
|
||||
"""
|
||||
num_orig_params_with_duplicates = len(tuple(_named_parameters(mod, remove_duplicate=False)))
|
||||
orig_params = tuple(mod.parameters())
|
||||
if len(orig_params) != num_orig_params_with_duplicates:
|
||||
raise_parameter_tying_error()
|
||||
|
||||
# Remove all the parameters in the model
|
||||
names = []
|
||||
for name, p in list(mod.named_parameters()):
|
||||
@ -91,7 +105,11 @@ def _swap_state(mod: nn.Module, split_names: List[str], elems):
|
||||
|
||||
|
||||
def extract_buffers(mod: nn.Module) -> Tuple[Tuple[Tensor, ...], List[str]]:
|
||||
num_orig_params_with_duplicates = len(tuple(_named_buffers(mod, remove_duplicate=False)))
|
||||
orig_params = tuple(mod.buffers())
|
||||
if len(orig_params) != num_orig_params_with_duplicates:
|
||||
raise_parameter_tying_error()
|
||||
|
||||
# Remove all the parameters in the model
|
||||
names = []
|
||||
for name, p in list(mod.named_buffers()):
|
||||
|
||||
32
functorch/functorch/_src/named_members_polyfill.py
Normal file
32
functorch/functorch/_src/named_members_polyfill.py
Normal file
@ -0,0 +1,32 @@
|
||||
# Polyfilled from pytorch core while we figure out the `remove_duplicate` issues.
|
||||
def _named_members(mod, get_members_fn, prefix='', recurse=True, remove_duplicate=True):
|
||||
r"""Helper method for yielding various names + members of modules."""
|
||||
memo = set()
|
||||
modules = mod.named_modules(prefix=prefix, remove_duplicate=remove_duplicate) if recurse else [(prefix, mod)]
|
||||
for module_prefix, module in modules:
|
||||
members = get_members_fn(module)
|
||||
for k, v in members:
|
||||
if v is None or v in memo:
|
||||
continue
|
||||
if remove_duplicate:
|
||||
memo.add(v)
|
||||
name = module_prefix + ('.' if module_prefix else '') + k
|
||||
yield name, v
|
||||
|
||||
|
||||
def _named_parameters(mod, prefix: str = '', recurse: bool = True, remove_duplicate: bool = True):
    """Yield (name, parameter) pairs of ``mod``, polyfilling ``remove_duplicate``."""
    yield from _named_members(
        mod,
        lambda module: module._parameters.items(),
        prefix=prefix, recurse=recurse, remove_duplicate=remove_duplicate)
|
||||
|
||||
|
||||
def _named_buffers(mod, prefix: str = '', recurse: bool = True, remove_duplicate: bool = True):
    """Yield (name, buffer) pairs of ``mod``, polyfilling ``remove_duplicate``."""
    yield from _named_members(
        mod,
        lambda module: module._buffers.items(),
        prefix=prefix, recurse=recurse, remove_duplicate=remove_duplicate)
|
||||
@ -1988,6 +1988,34 @@ class TestComposability(TestCase):
|
||||
# Honestly IDK what the result here is... but at least it runs
|
||||
|
||||
|
||||
class TestMakeFunctional(TestCase):
    """Checks that make_functional* rejects tied parameters/buffers loudly."""

    def test_parameter_tying(self):
        # The bias is shared between the module and its linear child, and the
        # linear child is registered twice — both are forms of parameter tying.
        class Foo(nn.Module):
            def __init__(self):
                super().__init__()
                self.bias = nn.Parameter(torch.randn(3))
                self.linear = nn.Linear(3, 3)
                self.linear.bias = self.bias
                self.linear_tied = self.linear

        tied_mod = Foo()
        with self.assertRaisesRegex(RuntimeError, "parameter tying"):
            func, params = make_functional(tied_mod)

    def test_buffer_tying(self):
        # Same buffer object registered under two names — buffer tying.
        class Foo(nn.Module):
            def __init__(self):
                super().__init__()
                self.bias = nn.Parameter(torch.randn(3))
                self.linear = nn.Linear(3, 3)
                self.register_buffer('buffer', torch.randn(3))
                self.register_buffer('buffer_tied', self.buffer)

        tied_mod = Foo()
        with self.assertRaisesRegex(RuntimeError, "parameter tying"):
            func, params, buffers = make_functional_with_buffers(tied_mod)
|
||||
|
||||
|
||||
class TestExamplesCorrectness(TestCase):
|
||||
def test_maml_regression(self, device):
|
||||
class ThreeLayerNet(nn.Module):
|
||||
|
||||
Reference in New Issue
Block a user