[stateless] add weight tying support (#90477)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/90477
Approved by: https://github.com/zou3519
samdow
2023-01-10 17:02:48 -05:00
committed by PyTorch MergeBot
parent e03ac0ee8c
commit 8b3c4bc481
3 changed files with 202 additions and 24 deletions

View File

@@ -22,6 +22,17 @@ class MockModule(torch.nn.Module):
def forward(self, x):
return self.l1(x) + self.buffer
class MockTiedModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.l1 = torch.nn.Linear(1, 1)
self.tied_bias = self.l1.bias
self.register_buffer('buffer', torch.ones(1))
self.register_buffer('tied_buffer', self.buffer)
def forward(self, x):
return self.l1(x) + self.tied_bias + self.buffer + self.tied_buffer
class TestStatelessFunctionalAPI(TestCase):
def _run_call_with_mock_module(self, module, functional_call, device='cpu', prefix=''):
@@ -156,7 +167,7 @@ class TestStatelessFunctionalAPI(TestCase):
'l1.m.buffer': buffer}
prev_weight = module.l1.weight.clone()
prev_buffer = module.buffer.clone()
res = functional_call(module, parameters, x)
res = functional_call(module, parameters, x, tie_weights=False)
self.assertEqual(x, res)
# check that the weights remain unmodified and were correctly accessed
cur_weight = module.l1.weight
@@ -217,6 +228,46 @@ class TestStatelessFunctionalAPI(TestCase):
module = MockModule()
module.tied_bias = module.l1.bias
module.register_buffer("tied_buffer", module.buffer)
@parametrize("functional_call", [
subtest(torch.func.functional_call, "torch_func"),
subtest(stateless.functional_call, "stateless")
])
def test_reparametrize_tie_weights(self, functional_call):
module = MockTiedModule()
weight = torch.tensor([[2.0]],)
bias = torch.tensor([5.0])
buffer = torch.tensor([3.0])
parameters = {'l1.weight': weight,
'l1.bias': bias,
'buffer': buffer}
x = torch.randn(1, 1)
out = functional_call(module, parameters, x, tie_weights=True)
self.assertEqual(out, x * weight + bias + bias + buffer + buffer)
@parametrize("functional_call", [
subtest(torch.func.functional_call, "torch_func"),
subtest(stateless.functional_call, "stateless")
])
def test_reparametrize_tie_some_weights(self, functional_call):
module = MockTiedModule()
weight = torch.tensor([[2.0]],)
buffer = torch.tensor([3.0])
parameters = {'l1.weight': weight,
'buffer': buffer}
x = torch.randn(1, 1)
out = functional_call(module, parameters, x, tie_weights=True)
self.assertEqual(out, x * 2. + module.l1.bias + module.tied_bias + buffer + buffer)
@parametrize("functional_call", [
subtest(torch.func.functional_call, "torch_func"),
subtest(stateless.functional_call, "stateless")
])
def test_tied_weights_errors(self, functional_call):
module = MockTiedModule()
weight = torch.tensor([[1.0]],)
bias = torch.tensor([0.0])
buffer = torch.tensor([0.0])
@@ -225,23 +276,41 @@ class TestStatelessFunctionalAPI(TestCase):
'l1.bias': bias,
'buffer': buffer}
x = torch.randn(1, 1)
self.assertNotWarn(lambda: functional_call(module, parameters, x))
self.assertNotWarn(lambda: functional_call(module, parameters, x, tie_weights=True))
# if tied values are the same tensors, shouldn't warn or error
parameters['tied_bias'] = bias
parameters['tied_buffer'] = buffer
self.assertNotWarn(lambda: functional_call(module, parameters, x))
self.assertNotWarn(lambda: functional_call(module, parameters, x, tie_weights=True))
del parameters['tied_bias']
del parameters['tied_buffer']
with self.assertWarnsOnceRegex(UserWarning, "functional_call was passed multiple values"):
with self.assertRaisesRegex(ValueError, "functional_call got values for both (l1.bias|tied_bias)"):
parameters['tied_bias'] = torch.tensor([5.0])
functional_call(module, parameters, x)
functional_call(module, parameters, x, tie_weights=True)
del parameters['tied_bias']
with self.assertWarnsOnceRegex(UserWarning, "functional_call was passed multiple values"):
with self.assertRaisesRegex(ValueError, "functional_call got values for both (buffer|tied_buffer)"):
parameters['tied_buffer'] = torch.tensor([5.0])
functional_call(module, parameters, x)
functional_call(module, parameters, x, tie_weights=True)
def test_tied_weights_no_error_without_flag(self):
module = MockTiedModule()
weight = torch.tensor([[1.0]],)
bias = torch.tensor([0.0])
buffer = torch.tensor([0.0])
parameters = {'l1.weight': weight,
'l1.bias': bias,
'buffer': buffer}
x = torch.randn(1, 1)
self.assertNotWarn(lambda: stateless.functional_call(module, parameters, x, tie_weights=False))
parameters['tied_bias'] = torch.tensor([5.0])
self.assertNotWarn(lambda: stateless.functional_call(module, parameters, x, tie_weights=False))
del parameters['tied_bias']
parameters['tied_buffer'] = torch.tensor([5.0])
self.assertNotWarn(lambda: stateless.functional_call(module, parameters, x, tie_weights=False))
@parametrize("functional_call", [
subtest(torch.func.functional_call, "torch_func"),
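The new tests above pin down the semantics of the flag. Here is a minimal standalone sketch of the same behavior, assuming a PyTorch build that includes this change; TiedModule is a local re-declaration of the MockTiedModule fixture:

import torch
from torch.nn.utils import stateless

class TiedModule(torch.nn.Module):
    # mirrors MockTiedModule: the bias and the buffer each have a tied alias
    def __init__(self):
        super().__init__()
        self.l1 = torch.nn.Linear(1, 1)
        self.tied_bias = self.l1.bias
        self.register_buffer('buffer', torch.ones(1))
        self.register_buffer('tied_buffer', self.buffer)

    def forward(self, x):
        return self.l1(x) + self.tied_bias + self.buffer + self.tied_buffer

module = TiedModule()
weight = torch.tensor([[2.0]])
bias = torch.tensor([5.0])
buffer = torch.tensor([3.0])
x = torch.randn(1, 1)

# With tie_weights=True, the values given for l1.bias and buffer also flow
# to tied_bias and tied_buffer, so each shows up twice in the output.
out = stateless.functional_call(
    module, {'l1.weight': weight, 'l1.bias': bias, 'buffer': buffer}, x,
    tie_weights=True)
assert torch.allclose(out, x * weight + bias + bias + buffer + buffer)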

View File

@@ -12,6 +12,8 @@ def functional_call(
parameter_and_buffer_dicts: Union[Dict[str, Tensor], Tuple[Dict[str, Tensor], ...]],
args: Union[Any, Tuple],
kwargs: Dict[str, Any] = None,
*,
tie_weights: bool = True,
):
r"""Performs a functional call on the module by replacing the module parameters
and buffers with the provided ones.
@@ -36,6 +38,21 @@ def functional_call(
>>> print(mod.foo) # tensor(0.)
>>> print(a['foo']) # tensor(1.)
.. note:: If the module has tied weights, whether or not functional_call respects the tying is determined by the
tie_weights flag.
Example::
>>> a = {'foo': torch.zeros(())}
>>> # xdoctest: +SKIP
>>> mod = Foo() # has both self.foo and self.foo_tied which are tied. Returns x + self.foo + self.foo_tied
>>> print(mod.foo) # tensor(1.)
>>> mod(torch.zeros(())) # tensor(2.)
>>> functional_call(mod, a, torch.zeros(())) # tensor(0.) since it will change self.foo_tied too
>>> functional_call(mod, a, torch.zeros(()), tie_weights=False) # tensor(1.) since self.foo_tied is not updated
>>> new_a = {'foo': torch.zeros(()), 'foo_tied': torch.zeros(())}
>>> functional_call(mod, new_a, torch.zeros(())) # tensor(0.)
An example of passing multiple dictionaries
.. code-block:: python
@@ -88,6 +105,10 @@ def functional_call(
be used together
args (Any or tuple): arguments to be passed to the module call. If not a tuple, considered a single argument.
kwargs (dict): keyword arguments to be passed to the module call
tie_weights (bool, optional): If True, then parameters and buffers tied in the original model will be treated as
tied in the reparametrized version. Therefore, if True and different values are passed for the tied
parameters and buffers, it will error. If False, it will not respect the originally tied parameters and
buffers unless the values passed for both weights are the same. Default: True.
Returns:
Any: the result of calling ``module``.
@@ -102,7 +123,7 @@ def functional_call(
parameters_and_buffers = {k: v for d in parameter_and_buffer_dicts for k, v in d.items()}
return nn.utils.stateless.functional_call(module, parameters_and_buffers, args, kwargs)
return nn.utils.stateless.functional_call(module, parameters_and_buffers, args, kwargs, tie_weights=tie_weights)
@exposed_in("torch.func")
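To make the docstring's note concrete: the Foo module in the doctest is only sketched there, but under the assumption that tying is plain attribute aliasing, it could look like the following, and the flag behaves as the doctest comments claim:

import torch
from torch.func import functional_call

class Foo(torch.nn.Module):
    # a concrete stand-in for the Foo in the doctest: self.foo and
    # self.foo_tied are the same Parameter registered under two names
    def __init__(self):
        super().__init__()
        self.foo = torch.nn.Parameter(torch.ones(()))
        self.foo_tied = self.foo

    def forward(self, x):
        return x + self.foo + self.foo_tied

mod = Foo()
a = {'foo': torch.zeros(())}
print(mod(torch.zeros(())))                      # tensor(2.)
print(functional_call(mod, a, torch.zeros(())))  # tensor(0.): foo_tied follows foo
print(functional_call(mod, a, torch.zeros(()), tie_weights=False))  # tensor(1.)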

View File

@@ -1,6 +1,5 @@
import warnings
import contextlib
from typing import Any, Callable, Dict, Iterator, List, Tuple, Union
from typing import Any, Callable, Dict, Iterator, List, Tuple, Union, Set, Optional
import torch
from torch import Tensor
@@ -37,21 +36,84 @@ def _change_class(module, params_and_buffers) -> None:
module._orig_class = cls
def _check_tied_val_already_replaced(old_val, new_val, replaced_tensors_map):
if old_val not in replaced_tensors_map:
replaced_tensors_map[old_val] = new_val
elif replaced_tensors_map[old_val] is not new_val:
warnings.warn("functional_call was passed multiple values for tied weights. "
"This behavior is deprecated and will be an error in future versions")
def _create_tied_weights_map(module: 'torch.nn.Module', params_and_buffers: Dict[str, Tensor]) -> Dict[str, str]:
"""
_create_tied_weights_map(module: Module, params_and_buffers: Dict[str, Tensor]) -> Dict[str, str]
Creates a map of {tied_name: name_given_by_user} for every weight whose tied counterpart was passed by the user.
ex: Foo() has self.foo and self.tied_foo, which are tied. If a user passed {'foo': ...} as the reparametrization,
this would return {'tied_foo': 'foo'}. Similarly, if a user passed {'tied_foo': ...}, this returns
{'foo': 'tied_foo'}.
ex: If there aren't any tied weights, there are no tied names to remap, so this returns an empty map: {}
ex: The map only contains entries for weights the user is reparametrizing. For example, if module = nn.Linear(...)
(which has no tied weights) and the user only passed a new value for 'bias', this also returns: {}
This is useful because we will start by reparametrizing all the keys of params_and_buffers, then all the keys from
this returned dictionary.
"""
# The basic algorithm looks like:
# - index all weights by their original tensor value to find tied weights
# - when we encounter a weight not used by the user, we save it in a set (second element in the tuple)
# - when we run into a weight used by the user, we save that separate from the set as the first element in the tuple
# - ending map looks like {tensor: (name_given_by_user, set(all_tied_names))}
# - then loop through the values of this map (name_given_by_user and set(all_tied_names))
# - for each element of all_tied_names, add {tied_name: name_given_by_user} to a new map
names = params_and_buffers.keys()
weight_to_name_and_tied_names: Dict[torch.Tensor, Tuple[Optional[str], Set[str]]] = {}
# create a map keyed by tensor value so that tied weights get mapped to the same key. The value is the interesting
# part: at the end it's (used_name, set_of_tied_names).
# For example, with tied weights self.foo and self.tied_foo and a user-passed value for self.foo, this map will
# end up as {torch.Tensor(...): ('foo', {'tied_foo'})}
def add_to_name_map(n: str, t: torch.Tensor):
# if the tensor hasn't been seen before, add it to the map
if t not in weight_to_name_and_tied_names:
weight_to_name_and_tied_names[t] = (n, set()) if n in names else (None, {n})
return
# if the name is not used by the user, we add it to the tied set
if n not in names:
weight_to_name_and_tied_names[t][1].add(n)
return
# check that the user didn't pass two different tensors for the same tied weight
first_seen_name = weight_to_name_and_tied_names[t][0]
# if they didn't pass multiple names for tied weights or used the same tensor, we set the used name
if first_seen_name is None or params_and_buffers[n] is params_and_buffers[first_seen_name]:
weight_to_name_and_tied_names[t] = (n, weight_to_name_and_tied_names[t][1])
return
raise ValueError(f"functional_call got values for both {n} and {first_seen_name}, which are tied. " +
"Consider using tie_weights=False")
tensor: Tensor
for name, tensor in module.named_parameters(remove_duplicate=False):
add_to_name_map(name, tensor)
for name, tensor in module.named_buffers(remove_duplicate=False):
add_to_name_map(name, tensor)
# make {tied_name: name_given_by_user} from pairs of (name_given_by_user, set(all_tied_names))
tied_weights_to_given_name = {}
for name_given_by_user, tied_names in weight_to_name_and_tied_names.values():
if name_given_by_user is None: # no mapping was passed for this tensor, use original tensor
continue
for tied_name in tied_names:
tied_weights_to_given_name[tied_name] = name_given_by_user
return tied_weights_to_given_name
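# A sketch of the resulting behavior, assuming a module shaped like the
# tests' MockTiedModule, where 'tied_bias' aliases 'l1.bias' and
# 'tied_buffer' aliases 'buffer':
#   _create_tied_weights_map(module, {'l1.bias': torch.zeros(1)})
#   # -> {'tied_bias': 'l1.bias'}: reparametrizing l1.bias drags tied_bias along
#   _create_tied_weights_map(module, {'l1.weight': torch.zeros(1, 1)})
#   # -> {} since l1.weight has no tied alias, nothing needs remapping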
def _create_swap_params(params_and_buffers, replaced_tensors_map):
def _swap_parameters(module, tensor_name: str, full_path: str, tensor: Tensor) -> None:
def _create_swap_params(params_and_buffers):
def _swap_parameters(module, tensor_name: str, full_path: str, tensor: Optional[Tensor]) -> None:
# Changes the module class to get a new __getattr__ dunder method
# that looks for the reparametrized tensor
if hasattr(module, tensor_name):
old_val = getattr(module, tensor_name)
_check_tied_val_already_replaced(old_val, tensor, replaced_tensors_map)
if hasattr(module, "_attr_to_path"):
module._attr_to_path[tensor_name] = full_path
else:
@@ -72,12 +134,17 @@ def _remove_swap(module, name: str, full_path: str) -> None:
def _reparametrize_module(
module: 'torch.nn.Module',
parameters_and_buffers: Dict[str, Tensor],
tie_weights: bool = False,
) -> Iterator[None]:
orig_tensors_to_replacements: Dict[Tensor, Tensor] = {}
tied_weights_map = _create_tied_weights_map(module, parameters_and_buffers) if tie_weights else {}
for name, tensor in parameters_and_buffers.items():
_apply_func_submodules(
_create_swap_params(parameters_and_buffers, orig_tensors_to_replacements),
_create_swap_params(parameters_and_buffers),
module, name.split("."), name, (tensor,))
for tied_name, user_given_name in tied_weights_map.items():
_apply_func_submodules(
_create_swap_params(parameters_and_buffers),
module, tied_name.split("."), user_given_name, (None,))
try:
yield
finally:
@@ -105,6 +172,8 @@ def functional_call(
parameters_and_buffers: Dict[str, Tensor],
args: Union[Any, Tuple],
kwargs: Dict[str, Any] = None,
*,
tie_weights: bool = True,
):
r"""Performs a functional call on the module by replacing the module parameters
and buffers with the provided ones.
@@ -128,12 +197,31 @@ def functional_call(
>>> print(mod.foo) # tensor(0.)
>>> print(a['foo']) # tensor(1.)
.. note:: If the module has tied weights, whether or not functional_call respects the tying is determined by the
tie_weights flag.
Example::
>>> a = {'foo': torch.zeros(())}
>>> # xdoctest: +SKIP
>>> mod = Foo() # has both self.foo and self.foo_tied which are tied. Returns x + self.foo + self.foo_tied
>>> print(mod.foo) # tensor(1.)
>>> mod(torch.zeros(())) # tensor(2.)
>>> functional_call(mod, a, torch.zeros(())) # tensor(0.) since it will change self.foo_tied too
>>> functional_call(mod, a, torch.zeros(()), tie_weights=False) # tensor(1.) since self.foo_tied is not updated
>>> new_a = {'foo': torch.zeros(()), 'foo_tied': torch.zeros(())}
>>> functional_call(mod, new_a, torch.zeros(())) # tensor(0.)
Args:
module (torch.nn.Module): the module to call
parameters_and_buffers (dict of str and Tensor): the parameters that will be used in
the module call.
args (Any or tuple): arguments to be passed to the module call. If not a tuple, considered a single argument.
kwargs (dict): keyword arguments to be passed to the module call
tie_weights (bool, optional): If True, then parameters and buffers tied in the original model will be treated as
tied in the reparametrized version. Therefore, if True and different values are passed for the tied
parameters and buffers, it will error. If False, it will not respect the originally tied parameters and
buffers unless the values passed for both weights are the same. Default: True.
Returns:
Any: the result of calling ``module``.
@@ -151,7 +239,7 @@ def functional_call(
raise RuntimeError("The stateless API can't be used with Jitted modules")
if kwargs is None:
kwargs = {}
with _reparametrize_module(module, parameters_and_buffers):
with _reparametrize_module(module, parameters_and_buffers, tie_weights):
if isinstance(args, tuple):
out = module(*args, **kwargs)
else:
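Putting the pieces together, the error path and the opt-out compose as follows; a short sketch assuming a build with this change, with TiedModule a pared-down version of the test fixture:

import torch
from torch.nn.utils import stateless

class TiedModule(torch.nn.Module):
    # pared-down tied fixture: tied_bias aliases l1.bias
    def __init__(self):
        super().__init__()
        self.l1 = torch.nn.Linear(1, 1)
        self.tied_bias = self.l1.bias

    def forward(self, x):
        return self.l1(x) + self.tied_bias

module = TiedModule()
params = {'l1.weight': torch.tensor([[1.0]]),
          'l1.bias': torch.tensor([0.0]),
          'tied_bias': torch.tensor([5.0])}  # conflicts with 'l1.bias'
x = torch.randn(1, 1)

try:
    # tie_weights=True (the default) rejects two different values for tied weights
    stateless.functional_call(module, params, x, tie_weights=True)
except ValueError as err:
    print(err)  # functional_call got values for both ... which are tied.

# tie_weights=False takes every entry at face value and runs without error
print(stateless.functional_call(module, params, x, tie_weights=False))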