Mirror of https://github.com/pytorch/pytorch.git (synced 2025-11-06 00:54:56 +08:00)
[stateless] add weight tying support (#90477)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90477 Approved by: https://github.com/zou3519
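In short: functional_call gains a tie_weights flag (default True). When a module has tied parameters or buffers, passing a value for one of the tied names reparametrizes all of them; with tie_weights=False only the names that were explicitly passed are replaced. A minimal sketch of the new behavior (illustrative only; TiedModule is a hypothetical stand-in modeled on the MockTiedModule test added below):

    import torch
    from torch.nn.utils import stateless

    class TiedModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.l1 = torch.nn.Linear(1, 1)
            self.tied_bias = self.l1.bias  # tied to l1.bias

        def forward(self, x):
            return self.l1(x) + self.tied_bias

    module = TiedModule()
    params = {'l1.weight': torch.tensor([[2.0]]), 'l1.bias': torch.tensor([5.0])}
    x = torch.randn(1, 1)

    # tie_weights=True (the default): the value passed for 'l1.bias' is also used for 'tied_bias'
    out = stateless.functional_call(module, params, x, tie_weights=True)   # x * 2.0 + 5.0 + 5.0

    # tie_weights=False: 'tied_bias' keeps the module's original value
    out = stateless.functional_call(module, params, x, tie_weights=False)  # x * 2.0 + 5.0 + original tied_bias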
@@ -22,6 +22,17 @@ class MockModule(torch.nn.Module):
    def forward(self, x):
        return self.l1(x) + self.buffer


class MockTiedModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.l1 = torch.nn.Linear(1, 1)
        self.tied_bias = self.l1.bias
        self.register_buffer('buffer', torch.ones(1))
        self.register_buffer('tied_buffer', self.buffer)

    def forward(self, x):
        return self.l1(x) + self.tied_bias + self.buffer + self.tied_buffer


class TestStatelessFunctionalAPI(TestCase):
    def _run_call_with_mock_module(self, module, functional_call, device='cpu', prefix=''):
@@ -156,7 +167,7 @@ class TestStatelessFunctionalAPI(TestCase):
                      'l1.m.buffer': buffer}
        prev_weight = module.l1.weight.clone()
        prev_buffer = module.buffer.clone()
-        res = functional_call(module, parameters, x)
+        res = functional_call(module, parameters, x, tie_weights=False)
        self.assertEqual(x, res)
        # check that the weights remain unmodified and were correctly accessed
        cur_weight = module.l1.weight
@@ -217,6 +228,46 @@ class TestStatelessFunctionalAPI(TestCase):
        module = MockModule()
        module.tied_bias = module.l1.bias
        module.register_buffer("tied_buffer", module.buffer)

    @parametrize("functional_call", [
        subtest(torch.func.functional_call, "torch_func"),
        subtest(stateless.functional_call, "stateless")
    ])
    def test_reparamertize_tie_weights(self, functional_call):
        module = MockTiedModule()
        weight = torch.tensor([[2.0]],)
        bias = torch.tensor([5.0])
        buffer = torch.tensor([3.0])

        parameters = {'l1.weight': weight,
                      'l1.bias': bias,
                      'buffer': buffer}
        x = torch.randn(1, 1)
        out = functional_call(module, parameters, x, tie_weights=True)
        self.assertEqual(out, x * weight + bias + bias + buffer + buffer)

    @parametrize("functional_call", [
        subtest(torch.func.functional_call, "torch_func"),
        subtest(stateless.functional_call, "stateless")
    ])
    def test_reparamertize_tie_some_weights(self, functional_call):
        module = MockTiedModule()
        weight = torch.tensor([[2.0]],)
        buffer = torch.tensor([3.0])

        parameters = {'l1.weight': weight,
                      'buffer': buffer}
        x = torch.randn(1, 1)
        out = stateless.functional_call(module, parameters, x, tie_weights=True)
        self.assertEqual(out, x * 2. + module.l1.bias + module.tied_bias + buffer + buffer)

    @parametrize("functional_call", [
        subtest(torch.func.functional_call, "torch_func"),
        subtest(stateless.functional_call, "stateless")
    ])
    def test_tied_weights_errors(self, functional_call):
        module = MockTiedModule()
        weight = torch.tensor([[1.0]],)
        bias = torch.tensor([0.0])
        buffer = torch.tensor([0.0])
@@ -225,23 +276,41 @@ class TestStatelessFunctionalAPI(TestCase):
                      'l1.bias': bias,
                      'buffer': buffer}
        x = torch.randn(1, 1)
-        self.assertNotWarn(lambda: functional_call(module, parameters, x))
+        self.assertNotWarn(lambda: functional_call(module, parameters, x, tie_weights=True))

        # if tied values are the same tensors, shouldn't warn
        parameters['tied_bias'] = bias
        parameters['tied_buffer'] = buffer
-        self.assertNotWarn(lambda: functional_call(module, parameters, x))
+        self.assertNotWarn(lambda: functional_call(module, parameters, x, tie_weights=True))
        del parameters['tied_bias']
        del parameters['tied_buffer']

-        with self.assertWarnsOnceRegex(UserWarning, "functional_call was passed multiple values"):
+        with self.assertRaisesRegex(ValueError, "functional_call got values for both (l1.bias|tied_bias)"):
            parameters['tied_bias'] = torch.tensor([5.0])
-            functional_call(module, parameters, x)
+            functional_call(module, parameters, x, tie_weights=True)
        del parameters['tied_bias']

-        with self.assertWarnsOnceRegex(UserWarning, "functional_call was passed multiple values"):
+        with self.assertRaisesRegex(ValueError, "functional_call got values for both (buffer|tied_buffer)"):
            parameters['tied_buffer'] = torch.tensor([5.0])
-            functional_call(module, parameters, x)
+            functional_call(module, parameters, x, tie_weights=True)

    def test_tied_weights_no_error_without_flag(self):
        module = MockTiedModule()
        weight = torch.tensor([[1.0]],)
        bias = torch.tensor([0.0])
        buffer = torch.tensor([0.0])

        parameters = {'l1.weight': weight,
                      'l1.bias': bias,
                      'buffer': buffer}
        x = torch.randn(1, 1)
        self.assertNotWarn(lambda: stateless.functional_call(module, parameters, x, tie_weights=False))
        parameters['tied_bias'] = torch.tensor([5.0])
        self.assertNotWarn(lambda: stateless.functional_call(module, parameters, x, tie_weights=False))
        del parameters['tied_bias']
        parameters['tied_buffer'] = torch.tensor([5.0])
        self.assertNotWarn(lambda: stateless.functional_call(module, parameters, x, tie_weights=False))

    @parametrize("functional_call", [
        subtest(torch.func.functional_call, "torch_func"),
@@ -12,6 +12,8 @@ def functional_call(
    parameter_and_buffer_dicts: Union[Dict[str, Tensor], Tuple[Dict[str, Tensor], ...]],
    args: Union[Any, Tuple],
    kwargs: Dict[str, Any] = None,
    *,
    tie_weights: bool = True,
):
    r"""Performs a functional call on the module by replacing the module parameters
    and buffers with the provided ones.
@@ -36,6 +38,21 @@ def functional_call(
        >>> print(mod.foo) # tensor(0.)
        >>> print(a['foo']) # tensor(1.)

    .. note:: If the module has tied weights, whether or not functional_call respects the tying is determined by the
        tie_weights flag.

        Example::

            >>> a = {'foo': torch.zeros(())}
            >>> # xdoctest: +SKIP
            >>> mod = Foo() # has both self.foo and self.foo_tied which are tied. Returns x + self.foo + self.foo_tied
            >>> print(mod.foo) # tensor(1.)
            >>> mod(torch.zeros(())) # tensor(2.)
            >>> functional_call(mod, a, torch.zeros(())) # tensor(0.) since it will change self.foo_tied too
            >>> functional_call(mod, a, torch.zeros(()), tie_weights=False) # tensor(1.)--self.foo_tied is not updated
            >>> new_a = {'foo': torch.zeros(()), 'foo_tied': torch.zeros(())}
            >>> functional_call(mod, new_a, torch.zeros(())) # tensor(0.)

    An example of passing multiple dictionaries

    .. code-block:: python
@@ -88,6 +105,10 @@ def functional_call(
            be used together
        args (Any or tuple): arguments to be passed to the module call. If not a tuple, considered a single argument.
        kwargs (dict): keyword arguments to be passed to the module call
        tie_weights (bool, optional): If True, then parameters and buffers tied in the original model will be treated as
            tied in the reparameterized version. Therefore, if True and different values are passed for the tied
            parameters and buffers, it will error. If False, it will not respect the originally tied parameters and
            buffers unless the values passed for both weights are the same. Default: True.

    Returns:
        Any: the result of calling ``module``.
@@ -102,7 +123,7 @@ def functional_call(

    parameters_and_buffers = {k: v for d in parameter_and_buffer_dicts for k, v in d.items()}

-    return nn.utils.stateless.functional_call(module, parameters_and_buffers, args, kwargs)
+    return nn.utils.stateless.functional_call(module, parameters_and_buffers, args, kwargs, tie_weights=tie_weights)


@exposed_in("torch.func")
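For reference, a small sketch (illustrative only, not part of the diff; it reuses module and x from the sketch near the top) of calling the torch.func wrapper above with multiple dictionaries; the wrapper merges them into one dict and forwards tie_weights to nn.utils.stateless.functional_call:

    params = {name: p.detach() for name, p in module.named_parameters()}
    buffers = {name: b.detach() for name, b in module.named_buffers()}
    out = torch.func.functional_call(module, (params, buffers), x, tie_weights=True)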
@@ -1,6 +1,5 @@
-import warnings
import contextlib
-from typing import Any, Callable, Dict, Iterator, List, Tuple, Union
+from typing import Any, Callable, Dict, Iterator, List, Tuple, Union, Set, Optional

import torch
from torch import Tensor
@@ -37,21 +36,84 @@ def _change_class(module, params_and_buffers) -> None:
    module._orig_class = cls


-def _check_tied_val_already_replaced(old_val, new_val, replaced_tensors_map):
-    if old_val not in replaced_tensors_map:
-        replaced_tensors_map[old_val] = new_val
-    elif replaced_tensors_map[old_val] is not new_val:
-        warnings.warn("functional_call was passed multiple values for tied weights. "
-                      "This behavior is deprecated and will be an error in future versions")
def _create_tied_weights_map(module: 'torch.nn.Module', params_and_buffers: Dict[str, Tensor]) -> Dict[str, str]:
    """
    _create_tied_weights_map(module: Module, params_and_buffers: Dict[str, Tensor]) -> Dict[str, str]

    Creates a weight map of {tied_name: name_given_by_user} for all weights where one of their tied names is passed
    by the user.

    ex: Foo() has self.foo and self.tied_foo, which are tied. If a user passes {'foo': ...} as the reparameterization,
    this returns {'tied_foo': 'foo'}. Similarly, if a user passes {'tied_foo': ...}, this returns {'foo': 'tied_foo'}.

    ex: If there aren't any tied weights, no tied names are found and this returns an empty map: {}

    ex: The map only contains entries for tied weights whose counterpart was passed by the user. For example, if
    module = nn.Linear(...) and the user only passes a new value for 'bias', this returns {} since 'bias' has no
    tied weights.

    This is useful because we start by reparameterizing all the keys of params_and_buffers, then all the keys from
    this returned dictionary.
    """

    # The basic algorithm looks like:
    # - index all weights by their original tensor value to find tied weights
    # - when we encounter a weight whose name was not passed by the user, we save the name in a set (the second
    #   element of the tuple)
    # - when we run into a weight whose name was passed by the user, we save that name separately as the first
    #   element of the tuple
    # - the resulting map looks like {tensor: (name_given_by_user, set(all_tied_names))}
    # - then loop through the values of this map (name_given_by_user and set(all_tied_names))
    # - for each element of all_tied_names, add {tied_name: name_given_by_user} to a new map
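    # As an illustrative walk-through (assuming MockTiedModule from the tests above) with
    # params_and_buffers = {'l1.bias': b}, the intermediate map ends up roughly as:
    #   {<l1.weight tensor>: (None, {'l1.weight'}),
    #    <l1.bias tensor>:   ('l1.bias', {'tied_bias'}),
    #    <buffer tensor>:    (None, {'buffer', 'tied_buffer'})}
    # and the returned map is {'tied_bias': 'l1.bias'}.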
    names = params_and_buffers.keys()
    weight_to_name_and_tied_names: Dict[torch.Tensor, Tuple[Optional[str], Set[str]]] = {}

    # Create a map keyed by tensor value so that tied weights get mapped to the same key. The value is the
    # interesting part: at the end it's (used_name, {tied_names}).
    # For example, in the first example where there are tied weights self.foo and self.tied_foo and the user passes a
    # value for self.foo, this will be {torch.Tensor(...): ('foo', {'tied_foo'})}
    def add_to_name_map(n: str, t: torch.Tensor):
        # if the tensor hasn't been seen before, add it to the map
        if t not in weight_to_name_and_tied_names:
            weight_to_name_and_tied_names[t] = (n, set()) if n in names else (None, {n})
            return

        # if the name is not used by the user, we add it to the tied set
        if n not in names:
            weight_to_name_and_tied_names[t][1].add(n)
            return

        # check that the user didn't pass two different tensors for the same tied weight
        first_seen_name = weight_to_name_and_tied_names[t][0]

        # if they didn't pass multiple names for tied weights or used the same tensor, we set the used name
        if first_seen_name is None or params_and_buffers[n] is params_and_buffers[first_seen_name]:
            weight_to_name_and_tied_names[t] = (n, weight_to_name_and_tied_names[t][1])
            return

        raise ValueError(f"functional_call got values for both {n} and {first_seen_name}, which are tied. " +
                         "Consider using tie_weights=False")

    tensor: Tensor
    for name, tensor in module.named_parameters(remove_duplicate=False):
        add_to_name_map(name, tensor)

    for name, tensor in module.named_buffers(remove_duplicate=False):
        add_to_name_map(name, tensor)

    # make {tied_name: name_given_by_user} from pairs of (name_given_by_user, set(all_tied_names))
    tied_weights_to_given_name = {}
    for name_given_by_user, tied_names in weight_to_name_and_tied_names.values():
        if name_given_by_user is None:  # no mapping was passed for this tensor, use original tensor
            continue
        for tied_name in tied_names:
            tied_weights_to_given_name[tied_name] = name_given_by_user
    return tied_weights_to_given_name

-def _create_swap_params(params_and_buffers, replaced_tensors_map):
-    def _swap_parameters(module, tensor_name: str, full_path: str, tensor: Tensor) -> None:
+def _create_swap_params(params_and_buffers):
+    def _swap_parameters(module, tensor_name: str, full_path: str, tensor: Optional[Tensor]) -> None:
        # Changes the module class to get a new __getattr__ dunder method
        # that looks for the reparametrized tensor
        if hasattr(module, tensor_name):
-            old_val = getattr(module, tensor_name)
-            _check_tied_val_already_replaced(old_val, tensor, replaced_tensors_map)
            if hasattr(module, "_attr_to_path"):
                module._attr_to_path[tensor_name] = full_path
            else:
@@ -72,12 +134,17 @@ def _remove_swap(module, name: str, full_path: str) -> None:
def _reparametrize_module(
    module: 'torch.nn.Module',
    parameters_and_buffers: Dict[str, Tensor],
    tie_weights: bool = False,
) -> Iterator[None]:
-    orig_tensors_to_replacements: Dict[Tensor, Tensor] = {}
+    tied_weights_map = _create_tied_weights_map(module, parameters_and_buffers) if tie_weights else {}
    for name, tensor in parameters_and_buffers.items():
        _apply_func_submodules(
-            _create_swap_params(parameters_and_buffers, orig_tensors_to_replacements),
+            _create_swap_params(parameters_and_buffers),
            module, name.split("."), name, (tensor,))
    for tied_name, user_given_name in tied_weights_map.items():
        _apply_func_submodules(
            _create_swap_params(parameters_and_buffers),
            module, tied_name.split("."), user_given_name, (None,))
    try:
        yield
    finally:
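For orientation, _reparametrize_module is a context manager; a minimal sketch of how the updated version is used (illustrative only, reusing module and x from the sketch near the top):

    # Both 'l1.bias' and its tied name 'tied_bias' resolve to the replacement inside the block.
    with _reparametrize_module(module, {'l1.bias': torch.zeros(1)}, tie_weights=True):
        out = module(x)
    # On exit the original parameters and buffers are visible again.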
@@ -105,6 +172,8 @@ def functional_call(
    parameters_and_buffers: Dict[str, Tensor],
    args: Union[Any, Tuple],
    kwargs: Dict[str, Any] = None,
    *,
    tie_weights: bool = True,
):
    r"""Performs a functional call on the module by replacing the module parameters
    and buffers with the provided ones.
@@ -128,12 +197,31 @@ def functional_call(
        >>> print(mod.foo) # tensor(0.)
        >>> print(a['foo']) # tensor(1.)

    .. note:: If the module has tied weights, whether or not functional_call respects the tying is determined by the
        tie_weights flag.

        Example::

            >>> a = {'foo': torch.zeros(())}
            >>> # xdoctest: +SKIP
            >>> mod = Foo() # has both self.foo and self.foo_tied which are tied. Returns x + self.foo + self.foo_tied
            >>> print(mod.foo) # tensor(1.)
            >>> mod(torch.zeros(())) # tensor(2.)
            >>> functional_call(mod, a, torch.zeros(())) # tensor(0.) since it will change self.foo_tied too
            >>> functional_call(mod, a, torch.zeros(()), tie_weights=False) # tensor(1.)--self.foo_tied is not updated
            >>> new_a = {'foo': torch.zeros(()), 'foo_tied': torch.zeros(())}
            >>> functional_call(mod, new_a, torch.zeros(())) # tensor(0.)

    Args:
        module (torch.nn.Module): the module to call
        parameters_and_buffers (dict of str and Tensor): the parameters that will be used in
            the module call.
        args (Any or tuple): arguments to be passed to the module call. If not a tuple, considered a single argument.
        kwargs (dict): keyword arguments to be passed to the module call
        tie_weights (bool, optional): If True, then parameters and buffers tied in the original model will be treated as
            tied in the reparameterized version. Therefore, if True and different values are passed for the tied
            parameters and buffers, it will error. If False, it will not respect the originally tied parameters and
            buffers unless the values passed for both weights are the same. Default: True.

    Returns:
        Any: the result of calling ``module``.
@@ -151,7 +239,7 @@ def functional_call(
        raise RuntimeError("The stateless API can't be used with Jitted modules")
    if kwargs is None:
        kwargs = {}
-    with _reparametrize_module(module, parameters_and_buffers):
+    with _reparametrize_module(module, parameters_and_buffers, tie_weights):
        if isinstance(args, tuple):
            out = module(*args, **kwargs)
        else: