[functorch][nn] Refactor NN stateless APIs by swapping module tensors (#92536)

- Fixes #92295
- Resolves #86708
- Resolves #92153
- Closes #92401
- Closes #92218

- Requires #91579

Refactor the NN stateless APIs (`torch.func.functional_call` and `torch.nn.utils.stateless`) to swap module tensors in and out through a new `NamedMemberAccessor`, instead of patching module classes. This also adds a `strict` keyword and reworks tied-weight handling via `_untie_named_tensors_map`.
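
For illustration, a minimal sketch of the new `strict` behavior (the module and values here are assumptions that mirror the tests in this diff, not part of the change itself):

import torch

module = torch.nn.Linear(1, 1)
x = torch.ones(1, 1)

# strict=True requires the given dict to cover the module's parameters and
# buffers exactly; any missing or unexpected key raises instead of silently
# doing a partial swap.
params = {'weight': torch.tensor([[2.0]])}  # 'bias' is missing
try:
    torch.func.functional_call(module, params, (x,), strict=True)
except RuntimeError as e:
    print(e)  # Error(s) in reparametrizing for Linear: Missing key(s): 'bias'.

# Without strict (the default), missing entries fall back to the module's own
# tensors, and the original tensors are restored after the call.
out = torch.func.functional_call(module, params, (x,))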

Pull Request resolved: https://github.com/pytorch/pytorch/pull/92536
Approved by: https://github.com/jbschlosser
Xuehai Pan
2023-02-08 17:31:38 +00:00
committed by PyTorch MergeBot
parent 3fd46a2f9c
commit b8de1cf007
6 changed files with 1089 additions and 223 deletions

View File

@ -835,9 +835,10 @@ include_patterns = [
'torch/_*.py',
'torch/testing/_internal/opinfo/**/*.py',
'torchgen/**/*.py',
'functorch/functorch/_src/aot_autograd.py',
'functorch/functorch/_src/compilers.py',
'torch/_functorch/make_functional.py',
'torch/_functorch/functional_call.py',
'torch/nn/utils/_named_member_accessor.py',
'torch/nn/utils/stateless.py',
'torch/testing/*.py',
'torch/distributed/fsdp/**/*.py',
'test/distributed/fsdp/**/*.py',

View File

@ -1,12 +1,13 @@
# Owner(s): ["module: nn"]
import unittest
import sys
import contextlib
import os
import re
import subprocess
import sys
import unittest
import torch
import torch.nn.utils.stateless as stateless
from torch.testing._internal.common_cuda import TEST_MULTIGPU
from torch.testing._internal.common_utils import run_tests, TestCase, parametrize, instantiate_parametrized_tests, \
@ -18,10 +19,12 @@ class MockModule(torch.nn.Module):
super().__init__()
self.l1 = torch.nn.Linear(1, 1)
self.register_buffer('buffer', torch.ones(1))
self.foo = 0.0
def forward(self, x):
return self.l1(x) + self.buffer
class MockTiedModule(torch.nn.Module):
def __init__(self):
super().__init__()
@ -65,6 +68,29 @@ class TestStatelessFunctionalAPI(TestCase):
self.assertEqual(cur_weight, prev_weight)
self.assertEqual(cur_buffer, prev_buffer)
@contextlib.contextmanager
def _ensure_module_unchanged(self, module, message):
orig_parameters, orig_buffers = tuple(module.parameters()), tuple(module.buffers())
orig_tensors = orig_parameters + orig_buffers
orig_tensors_values = tuple(t.clone() for t in orig_tensors)
try:
yield module
finally:
parameters, buffers = tuple(module.parameters()), tuple(module.buffers())
self.assertTrue(
len(parameters) == len(orig_parameters)
and len(buffers) == len(orig_buffers)
and all(
t1 is t2 and torch.allclose(t1, t3)
for t1, t2, t3 in zip(
orig_tensors,
parameters + buffers,
orig_tensors_values,
)
),
message,
)
@parametrize("functional_call", [
subtest(torch.func.functional_call, "torch_func"),
subtest(stateless.functional_call, "stateless")
@ -201,7 +227,7 @@ class TestStatelessFunctionalAPI(TestCase):
subtest(torch.func.functional_call, "torch_func"),
subtest(stateless.functional_call, "stateless")
])
def test_reparamertize_module_fail_reset_to_original(self, functional_call):
def test_reparametrize_module_fail_reset_to_original(self, functional_call):
module = MockModule()
torch.nn.utils.parametrizations.spectral_norm(module.l1)
self.assertTrue('l1.parametrizations.weight.original' in dict(module.named_parameters()))
@ -220,6 +246,161 @@ class TestStatelessFunctionalAPI(TestCase):
self.assertTrue('l1.parametrizations.weight.original' in dict(module.named_parameters()))
self.assertEqual(orig_sn_weight, module.l1.weight)
@parametrize("functional_call", [
subtest(torch.func.functional_call, "torch_func"),
subtest(stateless.functional_call, "stateless")
])
def test_reparametrize_some_weights(self, functional_call):
module = MockModule()
weight = torch.tensor([[2.0]])
bias = torch.tensor([5.0])
buffer = torch.tensor([3.0])
extra = torch.tensor([1.0])
parameters = {'l1.weight': weight}
x = torch.randn(1, 1)
out = functional_call(module, parameters, x)
self.assertEqual(out, x * weight + module.l1.bias + module.buffer)
parameters = {'l1.weight': weight,
'extra': extra}
x = torch.randn(1, 1)
out = functional_call(module, parameters, x)
self.assertEqual(out, x * weight + module.l1.bias + module.buffer)
@parametrize("functional_call", [
subtest(torch.func.functional_call, "torch_func"),
subtest(stateless.functional_call, "stateless")
])
def test_reparametrize_strict(self, functional_call):
module = MockModule()
weight = torch.tensor([[2.0]])
bias = torch.tensor([5.0])
buffer = torch.tensor([3.0])
extra = torch.tensor([1.0])
# All weights, no error
parameters = {'l1.weight': weight,
'l1.bias': bias,
'buffer': buffer}
x = torch.randn(1, 1)
with self._ensure_module_unchanged(
module,
'the module should not have been modified by a successful call',
):
out = functional_call(module, parameters, x, strict=True)
self.assertEqual(out, x * weight + bias + buffer)
# Some weights
parameters = {'l1.weight': weight}
x = torch.randn(1, 1)
with self._ensure_module_unchanged(
module,
'the module should not have been modified by a failed call',
):
with self.assertRaisesRegex(
RuntimeError,
re.escape("Missing key(s): 'buffer', 'l1.bias'."),
):
out = functional_call(module, parameters, x, strict=True)
# Extra keys
parameters = {'l1.weight': weight,
'l1.bias': bias,
'buffer': buffer,
'extra': extra}
x = torch.randn(1, 1)
with self._ensure_module_unchanged(
module,
'the module should not have been modified by a failed call',
):
with self.assertRaisesRegex(
RuntimeError,
re.escape("Unexpected key(s): 'extra'."),
):
out = functional_call(module, parameters, x, strict=True)
# Some weights with extra keys
parameters = {'l1.weight': weight,
'extra': extra}
x = torch.randn(1, 1)
with self._ensure_module_unchanged(
module,
'the module should not have been modified by a failed call',
):
with self.assertRaisesRegex(
RuntimeError,
re.escape("Unexpected key(s): 'extra'.") + r'\s+' + re.escape("Missing key(s): 'buffer', 'l1.bias'."),
):
out = functional_call(module, parameters, x, strict=True)
@parametrize("functional_call", [
subtest(torch.func.functional_call, "torch_func"),
subtest(stateless.functional_call, "stateless")
])
def test_reparametrize_special(self, functional_call):
class NonTensor:
def __repr__(self):
return f'<{self.__class__.__name__}>'
module = MockModule()
weight = torch.tensor([[2.0]])
bias = torch.tensor([5.0])
buffer = torch.tensor([3.0])
non_tensor = NonTensor()
# Set to None
parameters = {'l1.weight': weight,
'l1.bias': None,
'buffer': buffer}
x = torch.randn(1, 1)
with self._ensure_module_unchanged(
module,
'the module should not have been modified by a successful call',
):
out = functional_call(module, parameters, x)
self.assertEqual(out, x * weight + buffer)
# Set non-tensor
parameters = {'l1.weight': non_tensor}
x = torch.randn(1, 1)
with self._ensure_module_unchanged(
module,
'the module should not have been modified by a failed call',
):
with self.assertRaisesRegex(
TypeError,
re.escape("<NonTensor> is not an instance of torch.Tensor"),
):
out = functional_call(module, parameters, x)
# Set non-tensor attribute
parameters = {'l1.weight': weight, 'foo': torch.tensor([1.0])}
x = torch.randn(1, 1)
with self._ensure_module_unchanged(
module,
'the module should not have been modified by a failed call',
):
with self.assertRaisesRegex(
TypeError,
re.escape("attribute `foo`: 0.0 is not an instance of torch.Tensor"),
):
out = functional_call(module, parameters, x)
# Set non-existent submodule
parameters = {'l1.weight': weight,
'l2.bias': bias}
x = torch.randn(1, 1)
with self._ensure_module_unchanged(
module,
'the module should not have been modified by a failed call',
):
with self.assertRaisesRegex(
AttributeError,
re.escape("MockModule has no attribute `l2`"),
):
out = functional_call(module, parameters, x)
@parametrize("functional_call", [
subtest(torch.func.functional_call, "torch_func"),
subtest(stateless.functional_call, "stateless")
@ -233,11 +414,12 @@ class TestStatelessFunctionalAPI(TestCase):
subtest(torch.func.functional_call, "torch_func"),
subtest(stateless.functional_call, "stateless")
])
def test_reparamertize_tie_weights(self, functional_call):
def test_reparametrize_tie_weights(self, functional_call):
module = MockTiedModule()
weight = torch.tensor([[2.0]],)
weight = torch.tensor([[2.0]])
bias = torch.tensor([5.0])
buffer = torch.tensor([3.0])
extra = torch.tensor([1.0])
parameters = {'l1.weight': weight,
'l1.bias': bias,
@ -246,14 +428,21 @@ class TestStatelessFunctionalAPI(TestCase):
out = functional_call(module, parameters, x, tie_weights=True)
self.assertEqual(out, x * weight + bias + bias + buffer + buffer)
parameters = {'l1.weight': weight,
'l1.bias': bias,
'buffer': buffer,
'extra': extra}
x = torch.randn(1, 1)
out = functional_call(module, parameters, x, tie_weights=True)
self.assertEqual(out, x * weight + bias + bias + buffer + buffer)
@parametrize("functional_call", [
subtest(torch.func.functional_call, "torch_func"),
subtest(stateless.functional_call, "stateless")
])
def test_reparamertize_tie_some_weights(self, functional_call):
def test_reparametrize_tie_some_weights(self, functional_call):
module = MockTiedModule()
weight = torch.tensor([[2.0]],)
weight = torch.tensor([[2.0]])
buffer = torch.tensor([3.0])
parameters = {'l1.weight': weight,
@ -268,7 +457,7 @@ class TestStatelessFunctionalAPI(TestCase):
])
def test_tied_weights_errors(self, functional_call):
module = MockTiedModule()
weight = torch.tensor([[1.0]],)
weight = torch.tensor([[1.0]])
bias = torch.tensor([0.0])
buffer = torch.tensor([0.0])
@ -285,19 +474,24 @@ class TestStatelessFunctionalAPI(TestCase):
del parameters['tied_bias']
del parameters['tied_buffer']
with self.assertRaisesRegex(ValueError, "functional_call got values for both (l1.bias|tied_bias)"):
with self.assertRaisesRegex(
ValueError,
re.escape("functional_call got multiple values for keys ['l1.bias', 'tied_bias']"),
):
parameters['tied_bias'] = torch.tensor([5.0])
functional_call(module, parameters, x, tie_weights=True)
del parameters['tied_bias']
with self.assertRaisesRegex(ValueError, "functional_call got values for both (buffer|tied_buffer)"):
with self.assertRaisesRegex(
ValueError,
re.escape("functional_call got multiple values for keys ['buffer', 'tied_buffer']"),
):
parameters['tied_buffer'] = torch.tensor([5.0])
functional_call(module, parameters, x, tie_weights=True)
def test_tied_weights_no_error_without_flag(self):
module = MockTiedModule()
weight = torch.tensor([[1.0]],)
weight = torch.tensor([[1.0]])
bias = torch.tensor([0.0])
buffer = torch.tensor([0.0])
@ -312,6 +506,105 @@ class TestStatelessFunctionalAPI(TestCase):
parameters['tied_buffer'] = torch.tensor([5.0])
self.assertNotWarn(lambda: stateless._functional_call(module, parameters, x, tie_weights=False))
@parametrize("functional_call", [
subtest(torch.func.functional_call, "torch_func"),
subtest(stateless.functional_call, "stateless")
])
def test_reparametrize_tie_weights_strict(self, functional_call):
module = MockTiedModule()
weight = torch.tensor([[2.0]])
bias = torch.tensor([5.0])
buffer = torch.tensor([3.0])
extra = torch.tensor([1.0])
# Tie weights, no error
parameters = {'l1.weight': weight,
'l1.bias': bias,
'buffer': buffer}
x = torch.randn(1, 1)
with self._ensure_module_unchanged(
module,
'the module should not have been modified by a successful call',
):
out = functional_call(module, parameters, x, tie_weights=True, strict=True)
self.assertEqual(out, x * weight + bias + bias + buffer + buffer)
# Tie weights without flag
parameters = {'l1.weight': weight,
'l1.bias': bias,
'buffer': buffer}
x = torch.randn(1, 1)
with self._ensure_module_unchanged(
module,
'the module should not have been modified by a failed call',
):
with self.assertRaisesRegex(
RuntimeError,
re.escape("Missing key(s): 'tied_bias', 'tied_buffer'."),
):
out = functional_call(module, parameters, x, tie_weights=False, strict=True)
# Tie some weights
parameters = {'l1.weight': weight,
'buffer': buffer}
x = torch.randn(1, 1)
with self._ensure_module_unchanged(
module,
'the module should not have been modified by a failed call',
):
with self.assertRaisesRegex(
RuntimeError,
re.escape("Missing key(s): 'l1.bias', 'tied_bias'."),
):
out = functional_call(module, parameters, x, tie_weights=True, strict=True)
# Tie weights with extra keys
parameters = {'l1.weight': weight,
'l1.bias': bias,
'buffer': buffer,
'extra': extra}
x = torch.randn(1, 1)
with self._ensure_module_unchanged(
module,
'the module should not have been modified by a failed call',
):
with self.assertRaisesRegex(
RuntimeError,
re.escape("Unexpected key(s): 'extra'."),
):
out = functional_call(module, parameters, x, tie_weights=True, strict=True)
# Tie weights with extra keys and without flag
parameters = {'l1.weight': weight,
'l1.bias': bias,
'buffer': buffer,
'extra': extra}
x = torch.randn(1, 1)
with self._ensure_module_unchanged(
module,
'the module should not have been modified by a failed call',
):
with self.assertRaisesRegex(
RuntimeError,
re.escape("Unexpected key(s): 'extra'.") + r'\s+' + re.escape("Missing key(s): 'tied_bias', 'tied_buffer'."),
):
out = functional_call(module, parameters, x, tie_weights=False, strict=True)
# Tie some weights with extra keys
parameters = {'l1.weight': weight,
'buffer': buffer,
'extra': extra}
x = torch.randn(1, 1)
with self._ensure_module_unchanged(
module,
'the module should not have been modified by a failed call',
):
with self.assertRaisesRegex(
RuntimeError,
re.escape("Unexpected key(s): 'extra'.") + r'\s+' + re.escape("Missing key(s): 'l1.bias', 'tied_bias'."),
):
out = functional_call(module, parameters, x, tie_weights=True, strict=True)
@parametrize("functional_call", [
subtest(torch.func.functional_call, "torch_func"),
subtest(stateless.functional_call, "stateless")
@ -320,17 +613,89 @@ class TestStatelessFunctionalAPI(TestCase):
class Foo(torch.nn.Module):
def __init__(self):
super().__init__()
self.register_buffer('foo', torch.zeros(()))
self.register_buffer('foo', torch.tensor([0.0]))
def forward(self, x):
self.foo = self.foo + 1
return x + self.foo
a = {'foo': torch.zeros(())}
foo = torch.tensor([2.0])
x = torch.randn(1)
a = {'foo': foo}
mod = Foo()
functional_call(mod, a, torch.ones(()))
self.assertEqual(mod.foo, torch.zeros(()))
self.assertEqual(a['foo'], torch.ones(()))
functional_call(mod, a, x)
self.assertEqual(mod.foo, torch.tensor([0.0]))
self.assertEqual(a['foo'], torch.tensor([3.0]))
self.assertEqual(foo, torch.tensor([2.0]))
self.assertTrue(a['foo'] is not foo)
@parametrize("functional_call", [
subtest(torch.func.functional_call, "torch_func"),
subtest(stateless.functional_call, "stateless")
])
def test_in_place_operator(self, functional_call):
class Foo(torch.nn.Module):
def __init__(self):
super().__init__()
self.register_buffer('foo', torch.tensor([0.0]))
def forward(self, x):
self.foo.add_(1)
return x + self.foo
foo = torch.tensor([2.0])
x = torch.randn(1)
a = {'foo': foo}
mod = Foo()
functional_call(mod, a, x)
self.assertEqual(mod.foo, torch.tensor([0.0]))
self.assertEqual(a['foo'], torch.tensor([3.0]))
self.assertEqual(foo, torch.tensor([3.0]))
self.assertTrue(a['foo'] is foo)
@parametrize("functional_call", [
subtest(torch.func.functional_call, "torch_func"),
subtest(stateless.functional_call, "stateless")
])
def test_setattr_strict(self, functional_call):
class Bar(torch.nn.Module):
def __init__(self):
super().__init__()
assert not hasattr(self, 'extra')
def forward(self, x):
return x + self.extra
a = {'extra': torch.zeros(())}
mod = Bar()
self.assertTrue(not hasattr(mod, 'extra'))
out = functional_call(mod, a, torch.ones(()))
self.assertEqual(out, torch.ones(()))
self.assertTrue(not hasattr(mod, 'extra'))
a = {'extra': torch.zeros(())}
with self.assertRaisesRegex(
RuntimeError,
re.escape("Unexpected key(s): 'extra'."),
):
out = functional_call(mod, a, torch.ones(()), strict=True)
self.assertTrue(not hasattr(mod, 'extra'))
a = {}
with self.assertRaisesRegex(
AttributeError,
re.escape("'Bar' object has no attribute 'extra'"),
):
out = functional_call(mod, a, torch.ones(()))
self.assertTrue(not hasattr(mod, 'extra'))
a = {}
with self.assertRaisesRegex(
AttributeError,
re.escape("'Bar' object has no attribute 'extra'"),
):
out = functional_call(mod, a, torch.ones(()), strict=True)
self.assertTrue(not hasattr(mod, 'extra'))
@parametrize("functional_call", [
subtest(torch.func.functional_call, "torch_func"),
@ -355,7 +720,6 @@ class TestStatelessFunctionalAPI(TestCase):
res_1 = functional_call(mod, a, (), {'inp': inp, 'other_inp': other_inp})
self.assertEqual(res, res_1)
def test_functional_call_tuple_dicts(self):
mod = MockModule()
x = torch.rand((1, 1))
@ -375,15 +739,121 @@ class TestStatelessFunctionalAPI(TestCase):
res = torch.func.functional_call(mod, a, x)
self.assertEqual(res, x + 1)
def test_functional_call_multiple_dicts_error(self):
mod = MockModule()
x = torch.rand((1, 1))
parameters = {'l1.weight': torch.zeros((1, 1)), 'l1.bias': torch.zeros((1, 1))}
repeated_parameters = {'l1.weight': torch.ones((1, 1))}
with self.assertRaisesRegex(ValueError, "l1.weight appeared in multiple dictionaries"):
with self.assertRaisesRegex(
ValueError,
re.escape("['l1.weight'] appeared in multiple dictionaries"),
):
torch.func.functional_call(mod, (parameters, repeated_parameters), x)
@parametrize("functional_call", [
subtest(torch.func.functional_call, "torch_func"),
subtest(stateless.functional_call, "stateless")
])
def test_functional_call_member_reference(self, functional_call):
class Module(torch.nn.Module):
def __init__(self):
super().__init__()
self.l1 = torch.nn.Linear(1, 1)
self.register_buffer('buffer', torch.ones(1))
def forward(self, x):
parameters = tuple(self.parameters())
buffers = tuple(self.buffers())
return self.l1(x) + self.buffer, parameters, buffers
module = Module()
weight = torch.tensor([[2.0]])
bias = torch.tensor([5.0])
buffer = torch.tensor([3.0])
extra = torch.tensor([1.0])
extra_p = torch.nn.Parameter(extra)
# All weights
parameters = {'l1.weight': weight,
'l1.bias': bias,
'buffer': buffer}
x = torch.randn(1, 1)
out, parameters, buffers = functional_call(module, parameters, x)
self.assertEqual(out, x * weight + bias + buffer)
self.assertEqual(parameters, (weight, bias))
self.assertEqual(buffers, (buffer,))
self.assertTrue(all(t1 is t2 for t1, t2 in zip(parameters, (weight, bias))))
self.assertTrue(all(t1 is t2 for t1, t2 in zip(buffers, (buffer,))))
# Some weights
parameters = {'l1.weight': weight}
x = torch.randn(1, 1)
out, parameters, buffers = functional_call(module, parameters, x)
self.assertEqual(out, x * weight + module.l1.bias + module.buffer)
self.assertEqual(parameters, (weight, module.l1.bias))
self.assertEqual(buffers, (module.buffer,))
self.assertTrue(all(t1 is t2 for t1, t2 in zip(parameters, (weight, module.l1.bias))))
self.assertTrue(all(t1 is t2 for t1, t2 in zip(buffers, (module.buffer,))))
# All weights with extra keys
parameters = {'l1.weight': weight,
'l1.bias': bias,
'buffer': buffer,
'l1.extra': extra}
x = torch.randn(1, 1)
out, parameters, buffers = functional_call(module, parameters, x)
self.assertEqual(out, x * weight + bias + buffer)
self.assertEqual(parameters, (weight, bias))
self.assertEqual(buffers, (buffer,))
self.assertTrue(all(t1 is t2 for t1, t2 in zip(parameters, (weight, bias))))
self.assertTrue(all(t1 is t2 for t1, t2 in zip(buffers, (buffer,))))
# All weights with extra keys passed as Parameters
parameters = {'l1.weight': weight,
'l1.bias': bias,
'buffer': buffer,
'l1.extra': extra_p}
x = torch.randn(1, 1)
out, parameters, buffers = functional_call(module, parameters, x)
self.assertEqual(out, x * weight + bias + buffer)
self.assertEqual(parameters, (weight, bias, extra_p))
self.assertEqual(buffers, (buffer,))
self.assertTrue(all(t1 is t2 for t1, t2 in zip(parameters, (weight, bias, extra_p))))
self.assertTrue(all(t1 is t2 for t1, t2 in zip(buffers, (buffer,))))
# Some weights with extra keys
parameters = {'l1.weight': weight,
'l1.extra': extra}
x = torch.randn(1, 1)
out, parameters, buffers = functional_call(module, parameters, x)
self.assertEqual(out, x * weight + module.l1.bias + module.buffer)
self.assertEqual(parameters, (weight, module.l1.bias))
self.assertEqual(buffers, (module.buffer,))
self.assertTrue(all(t1 is t2 for t1, t2 in zip(parameters, (weight, module.l1.bias))))
self.assertTrue(all(t1 is t2 for t1, t2 in zip(buffers, (module.buffer,))))
# Some weights with extra keys passed as Parameters
parameters = {'l1.weight': weight,
'l1.extra': extra_p}
x = torch.randn(1, 1)
out, parameters, buffers = functional_call(module, parameters, x)
self.assertEqual(out, x * weight + module.l1.bias + module.buffer)
self.assertEqual(parameters, (weight, module.l1.bias, extra_p))
self.assertEqual(buffers, (module.buffer,))
self.assertTrue(all(t1 is t2 for t1, t2 in zip(parameters, (weight, module.l1.bias, extra_p))))
self.assertTrue(all(t1 is t2 for t1, t2 in zip(buffers, (module.buffer,))))
# Set to None
parameters = {'l1.weight': weight,
'l1.bias': None}
x = torch.randn(1, 1)
out, parameters, buffers = functional_call(module, parameters, x)
self.assertEqual(out, x * weight + module.buffer)
self.assertEqual(parameters, (weight,))
self.assertEqual(buffers, (module.buffer,))
self.assertTrue(all(t1 is t2 for t1, t2 in zip(parameters, (weight,))))
self.assertTrue(all(t1 is t2 for t1, t2 in zip(buffers, (module.buffer,))))
class TestStatelessDeprecation(TestCase):
def test_private_stateless_warns(self):

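The two buffer tests above pin down the write-back semantics of the swap-based implementation. A condensed sketch, with Foo as defined in the tests:

import torch

class Foo(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer('foo', torch.tensor([0.0]))

    def forward(self, x):
        self.foo = self.foo + 1  # rebinds the buffer out-of-place
        return x + self.foo

foo = torch.tensor([2.0])
a = {'foo': foo}
torch.func.functional_call(Foo(), a, (torch.zeros(1),))
assert torch.equal(a['foo'], torch.tensor([3.0]))  # updated value written back
assert a['foo'] is not foo  # ...as a new tensor, since forward rebound the name
# With an in-place `self.foo.add_(1)` instead, `foo` itself is mutated and
# `a['foo'] is foo` stays True (see test_in_place_operator above).
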
View File

@ -1,4 +1,5 @@
from typing import Dict, Union, Any, Tuple, List
from collections import Counter
from typing import Any, Dict, List, Sequence, Tuple, Union
import torch
import torch.nn as nn
@ -8,12 +9,13 @@ from torch._functorch.utils import exposed_in
@exposed_in("torch.func")
def functional_call(
module: 'torch.nn.Module',
parameter_and_buffer_dicts: Union[Dict[str, Tensor], Tuple[Dict[str, Tensor], ...]],
module: "torch.nn.Module",
parameter_and_buffer_dicts: Union[Dict[str, Tensor], Sequence[Dict[str, Tensor]]],
args: Union[Any, Tuple],
kwargs: Dict[str, Any] = None,
*,
tie_weights: bool = True,
strict: bool = False,
):
r"""Performs a functional call on the module by replacing the module parameters
and buffers with the provided ones.
@ -100,7 +102,7 @@ def functional_call(
Args:
module (torch.nn.Module): the module to call
parameters_and_buffers (Dict[str,Tensor] or tuple of Dict[str, Tensor]): the parameters that will be used in
parameters_and_buffers (Dict[str, Tensor] or tuple of Dict[str, Tensor]): the parameters that will be used in
the module call. If given a tuple of dictionaries, they must have distinct keys so that all dictionaries can
be used together
args (Any or tuple): arguments to be passed to the module call. If not a tuple, considered a single argument.
@ -109,25 +111,49 @@ def functional_call(
tied in the reparametrized version. Therefore, if True and different values are passed for the tied
parameters and buffers, it will error. If False, it will not respect the originally tied parameters and
buffers unless the values passed for both weights are the same. Default: True.
strict (bool, optional): If True, then the parameters and buffers passed in must match the parameters and
buffers in the original module. Therefore, if True and there are any missing or unexpected keys, it will
error. Default: False.
Returns:
Any: the result of calling ``module``.
"""
parameters_and_buffers = parameter_and_buffer_dicts if isinstance(parameter_and_buffer_dicts, dict) else {}
if isinstance(parameter_and_buffer_dicts, tuple):
key_list = [i for dct in parameter_and_buffer_dicts for i in dct.keys()]
key_set = set(key_list)
if len(key_set) != len(key_list):
repeated_key = list(filter(lambda key: key_list.count(key) > 1, key_set))[0]
raise ValueError(f"{repeated_key} appeared in multiple dictionaries; behavior of functional call is ambiguous")
if isinstance(parameter_and_buffer_dicts, dict):
parameters_and_buffers = parameter_and_buffer_dicts
elif isinstance(parameter_and_buffer_dicts, Sequence):
if not all(isinstance(d, dict) for d in parameter_and_buffer_dicts):
raise ValueError(
"Expected all elements of parameter_and_buffer_dicts to be dictionaries"
)
all_keys = [k for d in parameter_and_buffer_dicts for k in d.keys()]
repeated_keys = [key for key, n in Counter(all_keys).items() if n > 1]
if len(repeated_keys) > 0:
raise ValueError(
f"{repeated_keys} appeared in multiple dictionaries; behavior of functional call is ambiguous"
)
parameters_and_buffers = {
k: v for d in parameter_and_buffer_dicts for k, v in d.items()
}
else:
raise ValueError(
f"Expected parameter_and_buffer_dicts to be a dict, or a list/tuple of dicts, "
f"but got {type(parameter_and_buffer_dicts)}"
)
parameters_and_buffers = {k: v for d in parameter_and_buffer_dicts for k, v in d.items()}
return nn.utils.stateless._functional_call(module, parameters_and_buffers, args, kwargs, tie_weights=tie_weights)
return nn.utils.stateless._functional_call(
module,
parameters_and_buffers,
args,
kwargs,
tie_weights=tie_weights,
strict=strict,
)
@exposed_in("torch.func")
def stack_module_state(models: List[nn.Module]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
def stack_module_state(
models: List[nn.Module],
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
"""stack_module_state(models) -> params, buffers
Prepares a list of torch.nn.Modules for ensembling with :func:`vmap`.
@ -183,29 +209,39 @@ def stack_module_state(models: List[nn.Module]) -> Tuple[Dict[str, Any], Dict[st
same mode (training vs eval).
"""
if len(models) == 0:
raise RuntimeError('stack_module_state: Expected at least one model, got 0.')
raise RuntimeError("stack_module_state: Expected at least one model, got 0.")
if not (all(m.training for m in models) or all(not m.training for m in models)):
raise RuntimeError('stack_module_state: Expected all models to '
'have the same training/eval mode.')
raise RuntimeError(
"stack_module_state: Expected all models to have the same training/eval mode."
)
model0_typ = type(models[0])
if not all(type(m) == model0_typ for m in models):
raise RuntimeError('stack_module_state: Expected all models to '
'be of the same class.')
all_params = [{k: v for k, v in model.named_parameters()} for model in models]
params = {k: construct_stacked_leaf(tuple(params[k] for params in all_params), k)
for k in all_params[0]}
all_buffers = [{k: v for k, v in model.named_buffers()} for model in models]
buffers = {k: construct_stacked_leaf(tuple(buffers[k] for buffers in all_buffers), k)
for k in all_buffers[0]}
raise RuntimeError(
"stack_module_state: Expected all models to be of the same class."
)
all_params = [dict(model.named_parameters()) for model in models]
params = {
k: construct_stacked_leaf(tuple(params[k] for params in all_params), k)
for k in all_params[0]
}
all_buffers = [dict(model.named_buffers()) for model in models]
buffers = {
k: construct_stacked_leaf(tuple(buffers[k] for buffers in all_buffers), k)
for k in all_buffers[0]
}
return params, buffers
def construct_stacked_leaf(tensors, name):
all_requires_grad = all([t.requires_grad for t in tensors])
none_requires_grad = all([not t.requires_grad for t in tensors])
def construct_stacked_leaf(
tensors: Union[Tuple[Tensor, ...], List[Tensor]], name: str
) -> Tensor:
all_requires_grad = all(t.requires_grad for t in tensors)
none_requires_grad = all(not t.requires_grad for t in tensors)
if not all_requires_grad and not none_requires_grad:
raise RuntimeError(
f'Expected {name} from each model to have the same .requires_grad')
f"Expected {name} from each model to have the same .requires_grad"
)
result = torch.stack(tensors)
if all_requires_grad:
result = result.detach().requires_grad_()

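For context, a hedged sketch of the ensembling workflow `stack_module_state` is meant for; the model shapes and the meta-device skeleton below are assumptions, not part of this diff:

import copy
import torch
import torch.nn as nn
from torch.func import functional_call, stack_module_state, vmap

models = [nn.Linear(3, 2) for _ in range(4)]
params, buffers = stack_module_state(models)  # each leaf gains a leading dim of 4

# A stateless skeleton to call against; its own (meta) tensors are never used.
base = copy.deepcopy(models[0]).to('meta')

def run(p, b, x):
    # Tuple-of-dicts form accepted by the refactored functional_call above.
    return functional_call(base, (p, b), (x,))

x = torch.randn(4, 3)
out = vmap(run)(params, buffers, x)  # one independent forward per model
assert out.shape == (4, 2)
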
View File

@ -21,43 +21,13 @@ from typing import (
import torch
import torch.nn as nn
from torch import Tensor
from torch.nn.utils._named_member_accessor import NamedMemberAccessor
# Utilities to make nn.Module "functional"
# In particular the goal is to be able to provide a function that takes as input
# the parameters and evaluate the nn.Module using fixed inputs.
def _del_nested_attr(obj: nn.Module, names: List[str]) -> None:
"""
Deletes the attribute specified by the given list of names.
For example, to delete the attribute obj.conv.weight,
use _del_nested_attr(obj, ['conv', 'weight'])
"""
if len(names) == 1:
delattr(obj, names[0])
else:
_del_nested_attr(getattr(obj, names[0]), names[1:])
def _set_nested_attr(obj: nn.Module, names: List[str], value: Tensor) -> None:
"""
Set the attribute specified by the given list of names to value.
For example, to set the attribute obj.conv.weight,
use _set_nested_attr(obj, ['conv', 'weight'], value)
"""
if len(names) == 1:
setattr(obj, names[0], value)
else:
_set_nested_attr(getattr(obj, names[0]), names[1:], value)
def _get_nested_attr(obj: nn.Module, names: List[str]) -> Tensor:
if len(names) == 1:
return getattr(obj, names[0])
else:
return _get_nested_attr(getattr(obj, names[0]), names[1:])
def raise_parameter_tying_error() -> NoReturn:
raise RuntimeError(
"make_functional(module): we don't yet support models that "
@ -71,14 +41,14 @@ def raise_parameter_tying_error() -> NoReturn:
def create_names_map(
named_params: Union[Dict[str, Tensor], Iterable[Tuple[str, Tensor]]],
tied_named_params: Union[Dict[str, Tensor], Iterable[Tuple[str, Tensor]]],
) -> Dict[str, List[List[str]]]:
) -> Dict[str, List[str]]:
"""
named_params is a dictionary of tensors: {'A': A, 'B': B}
tied_named_params is another dictionary of tensors {'A': A, 'B': B, 'B_tied': B}
with potentially tied (or 'duplicated') tensors
This function creates a mapping from the names in named_params to the
names in tied_named_params: {'A': [['A']], 'B': [['B'], ['B_tied']]}.
names in tied_named_params: {'A': ['A'], 'B': ['B', 'B_tied']}.
"""
named_params = dict(named_params)
tied_named_params = dict(tied_named_params)
@ -87,12 +57,12 @@ def create_names_map(
tied_tensors_dict_keys = set(tied_named_params.keys())
assert tensors_dict_keys.issubset(tied_tensors_dict_keys)
tensor_to_mapping: Dict[Tensor, Tuple[str, List[List[str]]]] = {}
tensor_to_mapping: Dict[Tensor, Tuple[str, List[str]]] = {}
for key, tensor in named_params.items():
tensor_to_mapping[tensor] = (key, [])
for key, tensor in tied_named_params.items():
assert tensor in tensor_to_mapping
tensor_to_mapping[tensor][1].append(key.split("."))
tensor_to_mapping[tensor][1].append(key)
return dict(tensor_to_mapping.values())
@ -100,18 +70,19 @@ def _extract_members(
mod: nn.Module,
named_members: Callable[..., Iterable[Tuple[str, Tensor]]],
subclass: Callable[[Tensor], Tensor],
) -> Tuple[Tuple[Tensor, ...], Tuple[str, ...], Dict[str, List[List[str]]]]:
) -> Tuple[Tuple[Tensor, ...], Tuple[str, ...], Dict[str, List[str]]]:
all_named_members = tuple(named_members(remove_duplicate=False))
unique_named_members = tuple(named_members(remove_duplicate=True))
names_map = create_names_map(unique_named_members, all_named_members)
# Remove all the members in the model
memo = {}
accessor = NamedMemberAccessor(mod)
for name, p in all_named_members:
if p not in memo:
memo[p] = subclass(torch.empty_like(p, device="meta"))
replacement = memo[p]
_set_nested_attr(mod, name.split("."), replacement)
accessor.set_tensor(name, replacement)
if len(unique_named_members) == 0:
names, params = (), ()
@ -122,7 +93,7 @@ def _extract_members(
def extract_weights(
mod: nn.Module,
) -> Tuple[Tuple[Tensor, ...], Tuple[str, ...], Dict[str, List[List[str]]]]:
) -> Tuple[Tuple[Tensor, ...], Tuple[str, ...], Dict[str, List[str]]]:
"""
This function removes all the Parameters from the model and
returns them as a tuple, as well as their original attribute names.
@ -136,7 +107,7 @@ def extract_weights(
def extract_buffers(
mod: nn.Module,
) -> Tuple[Tuple[Tensor, ...], Tuple[str, ...], Dict[str, List[List[str]]]]:
) -> Tuple[Tuple[Tensor, ...], Tuple[str, ...], Dict[str, List[str]]]:
return _extract_members(mod, mod.named_buffers, lambda x: x)
@ -151,23 +122,23 @@ def load_weights(
Note that the `params` are regular Tensors (that can have history) and so are left
as Tensors. This means that mod.parameters() will still be empty after this call.
"""
for name, p in zip(names, params):
if as_params:
p = nn.Parameter(p)
_del_nested_attr(mod, name.split("."))
_set_nested_attr(mod, name.split("."), p)
accessor = NamedMemberAccessor(mod)
if as_params:
params = [nn.Parameter(p) for p in params]
accessor.set_tensors(names, params)
def _swap_state(
mod: nn.Module, names_map: Dict[str, List[List[str]]], elems: Iterable[Tensor]
mod: nn.Module, names_map: Dict[str, List[str]], elems: Iterable[Tensor]
) -> List[Tensor]:
result: List[Tensor] = []
accessor = NamedMemberAccessor(mod)
for (_, attr_names), elem in zip(names_map.items(), elems):
for i, attr_name in enumerate(attr_names):
if i == 0:
result.append(_get_nested_attr(mod, attr_name))
_del_nested_attr(mod, attr_name)
_set_nested_attr(mod, attr_name, elem)
result.append(accessor.swap_tensor(attr_name, elem))
else:
accessor.set_tensor(attr_name, elem)
return result
@ -177,8 +148,8 @@ def load_buffers(
buffers: Sequence[Tensor],
as_params: bool = False,
) -> None:
for name, p in zip(names, buffers):
_set_nested_attr(mod, name.split("."), p)
accessor = NamedMemberAccessor(mod)
accessor.set_tensors(names, buffers)
def load_state(
@ -290,8 +261,8 @@ class FunctionalModuleWithBuffers(nn.Module):
stateless_model: nn.Module,
param_names: Tuple[str, ...],
buffer_names: Tuple[str, ...],
param_names_map: Dict[str, List[List[str]]],
buffer_names_map: Dict[str, List[List[str]]],
param_names_map: Dict[str, List[str]],
buffer_names_map: Dict[str, List[str]],
) -> None:
super(FunctionalModuleWithBuffers, self).__init__()
self.stateless_model = stateless_model
@ -345,7 +316,7 @@ class FunctionalModule(nn.Module):
self,
stateless_model: nn.Module,
param_names: Tuple[str, ...],
names_map: Dict[str, List[List[str]]],
names_map: Dict[str, List[str]],
) -> None:
super(FunctionalModule, self).__init__()
self.stateless_model = stateless_model
@ -567,8 +538,7 @@ def combine_state_for_ensemble(
model0_typ = type(models[0])
if not all(type(m) == model0_typ for m in models):
raise RuntimeError(
"combine_state_for_ensemble: Expected all models to "
"be of the same class."
"combine_state_for_ensemble: Expected all models to be of the same class."
)
funcs, params, buffers = zip(
*[make_functional_with_buffers(model) for model in models]

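These helpers back the `make_functional*` entry points; a small usage sketch, assuming the public `functorch.make_functional` wrapper that is built on them:

import torch
import torch.nn as nn
from functorch import make_functional

fmodel, params = make_functional(nn.Linear(1, 1))
x = torch.ones(1, 1)
# Parameters are swapped into the stateless model via NamedMemberAccessor for
# the duration of the call, then the meta-device placeholders are restored.
out = fmodel(params, x)
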
View File

@ -0,0 +1,341 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import Dict, Iterable, List, Tuple
import torch
_MISSING: torch.Tensor = object() # type: ignore[assignment]
def set_tensor(module: "torch.nn.Module", name: str, tensor: torch.Tensor) -> None:
if not isinstance(module, torch.nn.Module):
raise TypeError(f"{module} is not an instance of torch.nn.Module")
if not isinstance(tensor, torch.Tensor) and tensor is not None:
raise TypeError(f"{tensor} is not an instance of torch.Tensor")
if "." in name:
raise KeyError('tensor name can\'t contain "."')
if name == "":
raise KeyError('tensor name can\'t be empty string ""')
if name in module._parameters:
module._parameters[name] = tensor # type: ignore[assignment]
elif name in module._buffers:
module._buffers[name] = tensor
else:
setattr(module, name, tensor)
def swap_tensor(
module: "torch.nn.Module",
name: str,
tensor: torch.Tensor,
allow_missing: bool = False,
) -> torch.Tensor:
if not isinstance(module, torch.nn.Module):
raise TypeError(f"{module} is not an instance of torch.nn.Module")
if (
tensor is not _MISSING
and not isinstance(tensor, torch.Tensor)
and tensor is not None
):
raise TypeError(f"{tensor} is not an instance of torch.Tensor")
if "." in name:
raise KeyError('tensor name can\'t contain "."')
if name == "":
raise KeyError('tensor name can\'t be empty string ""')
orig_tensor: torch.Tensor
if name in module._parameters:
orig_tensor = module._parameters[name] # type: ignore[assignment]
if tensor is not _MISSING:
module._parameters[name] = tensor # type: ignore[assignment]
else:
del module._parameters[name]
elif name in module._buffers:
orig_tensor = module._buffers[name] # type: ignore[assignment]
if tensor is not _MISSING:
module._buffers[name] = tensor
else:
del module._buffers[name]
else:
try:
orig_tensor = getattr(module, name)
except AttributeError as ex:
if not allow_missing:
raise AttributeError(
f"{module._get_name()} has no attribute `{name}`"
) from ex
orig_tensor = _MISSING
if (
orig_tensor is not _MISSING
and not isinstance(orig_tensor, torch.Tensor)
and orig_tensor is not None
):
raise TypeError(
f"attribute `{name}`: {orig_tensor} is not an instance of torch.Tensor"
)
if tensor is not _MISSING:
setattr(module, name, tensor)
elif hasattr(module, name):
delattr(module, name)
return orig_tensor
class NamedMemberAccessor:
"""
A class that provides a way to access the submodules and parameters/buffers
of a module. It provides a caching mechanism to speed up submodule lookups.
This is useful for functional programming to manipulate the module state.
"""
def __init__(self, module: "torch.nn.Module") -> None:
self.module = module
self.memo: Dict[str, torch.nn.Module] = {}
# Nested attribute access
def get_submodule(self, name: str) -> "torch.nn.Module":
"""
Return the submodule specified by the given path.
For example, to get the submodule mod.layer1.conv1,
use accessor.get_submodule("layer1.conv1")
Compared to mod.get_submodule("layer1.conv1"), this method will cache the
intermediate submodule access to speed up future lookups.
"""
if not name:
return self.module
try:
return self.memo[name]
except KeyError:
prefix, dot, attr = name.rpartition(".")
if dot:
module = self.get_submodule(prefix)
else:
module = self.module
try:
submodule = getattr(module, attr)
except AttributeError as ex:
raise AttributeError(
f"{module._get_name()} has no attribute `{attr}`"
) from ex
if not isinstance(submodule, torch.nn.Module):
raise TypeError(
f"submodule `{name}`: {submodule} is not an instance of torch.nn.Module"
)
self.memo[name] = submodule
return submodule
def get_tensor(self, name: str) -> torch.Tensor:
"""
Get the tensor specified by the given path.
For example, to get the attribute mod.layer1.conv1.weight,
use accessor.get_tensor('layer1.conv1.weight')
Compared to mod.get_parameter("layer1.conv1.weight"), this method will
cache the intermediate submodule access to speed up future lookups.
"""
prefix, _, attr = name.rpartition(".")
submodule = self.get_submodule(prefix)
try:
tensor = getattr(submodule, attr)
except AttributeError as ex:
raise AttributeError(
f"{submodule._get_name()} has no attribute `{name}`"
) from ex
if not isinstance(tensor, torch.Tensor) and tensor is not None:
raise TypeError(f"{tensor} is not an instance of torch.Tensor")
return tensor # type: ignore[return-value]
def set_tensor(self, name: str, value: torch.Tensor) -> None:
"""
Set the attribute specified by the given path to value.
For example, to set the attribute mod.layer1.conv1.weight,
use accessor.set_tensor("layer1.conv1.weight", value)
"""
prefix, _, attr = name.rpartition(".")
set_tensor(self.get_submodule(prefix), attr, value)
def del_tensor(self, name: str) -> None:
"""
Delete the attribute specified by the given path.
For example, to delete the attribute mod.layer1.conv1.weight,
use accessor.del_tensor("layer1.conv1.weight")
"""
prefix, _, attr = name.rpartition(".")
submodule = self.get_submodule(prefix)
try:
delattr(submodule, attr)
except AttributeError as ex:
raise AttributeError(
f"{submodule._get_name()} has no attribute `{name}`"
) from ex
def swap_tensor(
self, name: str, value: torch.Tensor, allow_missing: bool = False
) -> torch.Tensor:
"""
Swap the attribute specified by the given path with value, returning the original tensor.
For example, to swap the attribute mod.layer1.conv1.weight,
use accessor.swap_tensor("layer1.conv1.weight", value)
"""
prefix, _, attr = name.rpartition(".")
return swap_tensor(
self.get_submodule(prefix), attr, value, allow_missing=allow_missing
)
# Batched operations
def get_tensors(self, names: Iterable[str]) -> List[torch.Tensor]:
"""
Get the tensors specified by the given paths.
For example, to get the attributes mod.layer1.conv1.weight and
mod.layer1.conv1.bias, use accessor.get_tensors(["layer1.conv1.weight",
"layer1.conv1.bias"])
"""
return [self.get_tensor(name) for name in names]
def set_tensors(self, names: Iterable[str], values: Iterable[torch.Tensor]) -> None:
"""
Set the attributes specified by the given paths to values.
For example, to set the attributes mod.layer1.conv1.weight and
mod.layer1.conv1.bias, use accessor.set_tensors(["layer1.conv1.weight",
"layer1.conv1.bias"], [weight, bias])
"""
if not isinstance(names, (list, tuple)):
names = list(names)
if not isinstance(values, (list, tuple)):
values = list(values)
assert len(names) == len(values), "names and values must have the same length"
for name, value in zip(names, values):
self.set_tensor(name, value)
def set_tensors_dict(self, named_tensors: Dict[str, torch.Tensor]) -> None:
"""
Set the attributes specified by the given paths to values.
For example, to set the attributes mod.layer1.conv1.weight and
mod.layer1.conv1.bias, use accessor.set_tensors_dict({
"layer1.conv1.weight": weight,
"layer1.conv1.bias": bias,
})
"""
for name, value in named_tensors.items():
self.set_tensor(name, value)
def del_tensors(self, names: Iterable[str]) -> None:
"""
Delete the attributes specified by the given paths.
For example, to delete the attributes mod.layer1.conv1.weight and
mod.layer1.conv1.bias, use accessor.del_tensors(["layer1.conv1.weight",
"layer1.conv1.bias"])
"""
for name in names:
self.del_tensor(name)
def swap_tensors(
self,
names: Iterable[str],
values: Iterable[torch.Tensor],
allow_missing: bool = False,
) -> List[torch.Tensor]:
"""
Swap the attributes specified by the given paths with values, returning the original tensors.
For example, to swap the attributes mod.layer1.conv1.weight and
mod.layer1.conv1.bias, use accessor.swap_tensors(["layer1.conv1.weight",
"layer1.conv1.bias"], [weight, bias])
"""
if not isinstance(names, (list, tuple)):
names = list(names)
if not isinstance(values, (list, tuple)):
values = list(values)
assert len(names) == len(values), "names and values must have the same length"
return [
self.swap_tensor(name, value, allow_missing=allow_missing)
for name, value in zip(names, values)
]
def swap_tensors_dict(
self, named_tensors: Dict[str, torch.Tensor], allow_missing: bool = False
) -> Tuple[Dict[str, torch.Tensor], List[str]]:
"""
Swap the attributes specified by the given paths with values, returning the original tensors and the list of missing keys.
For example, to swap the attributes mod.layer1.conv1.weight and
mod.layer1.conv1.bias, use accessor.swap_tensors_dict({
"layer1.conv1.weight": weight,
"layer1.conv1.bias": bias,
})
"""
orig_named_tensors = {}
missing_keys = []
try:
for name, tensor in named_tensors.items():
orig_tensor = self.swap_tensor(name, tensor, allow_missing=True)
if orig_tensor is _MISSING:
missing_keys.append(name)
orig_named_tensors[name] = orig_tensor
except Exception:
# Swap back if any exception occurs
for name, orig_tensor in orig_named_tensors.items():
self.swap_tensor(name, orig_tensor, allow_missing=True)
raise
if missing_keys and not allow_missing:
# Swap back if any key is missing when allow_missing is False
for name, orig_tensor in orig_named_tensors.items():
self.swap_tensor(name, orig_tensor, allow_missing=True)
raise RuntimeError(
"Missing key(s): {}.".format(", ".join(map(repr, missing_keys)))
)
return orig_named_tensors, missing_keys
def check_keys(self, keys: Iterable[str]) -> Tuple[List[str], List[str]]:
"""
Check the given keys against the module's tensors and return (missing_keys, unexpected_keys).
"""
keys = set(keys)
valid_keys = set(name for name, _ in self.named_tensors(remove_duplicate=False))
missing_keys = valid_keys - keys
unexpected_keys = keys - valid_keys
return sorted(missing_keys), sorted(unexpected_keys)
# Shortcut methods
def named_parameters(
self,
remove_duplicate: bool = True,
) -> Iterable[Tuple[str, torch.Tensor]]:
"""
Iterate over all the parameters in the module.
"""
yield from self.module.named_parameters(remove_duplicate=remove_duplicate)
def named_buffers(
self,
remove_duplicate: bool = True,
) -> Iterable[Tuple[str, torch.Tensor]]:
"""
Iterate over all the buffers in the module.
"""
yield from self.module.named_buffers(remove_duplicate=remove_duplicate)
def named_tensors(
self,
remove_duplicate: bool = True,
) -> Iterable[Tuple[str, torch.Tensor]]:
"""
Iterate over all the tensors in the module.
"""
yield from self.module.named_parameters(remove_duplicate=remove_duplicate)
yield from self.module.named_buffers(remove_duplicate=remove_duplicate)
def named_modules(
self,
remove_duplicate: bool = True,
) -> Iterable[Tuple[str, "torch.nn.Module"]]:
"""
Iterate over all the modules in the module.
"""
yield from self.module.named_modules(remove_duplicate=remove_duplicate)

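A brief usage sketch of the new accessor; the module and tensor names below are illustrative assumptions:

import torch
import torch.nn as nn
from torch.nn.utils._named_member_accessor import NamedMemberAccessor

model = nn.Sequential(nn.Linear(2, 2))
accessor = NamedMemberAccessor(model)

new_weight = torch.zeros(2, 2)
orig = accessor.swap_tensor('0.weight', new_weight)  # returns the replaced tensor
assert model[0].weight is new_weight
accessor.swap_tensor('0.weight', orig)  # restore the original parameter

# The dict form used by _reparametrize_module returns the originals plus any
# keys that were missing from the module.
orig_map, missing = accessor.swap_tensors_dict(
    {'0.bias': torch.ones(2)}, allow_missing=True
)
accessor.swap_tensors_dict(orig_map, allow_missing=True)  # swap back
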
View File

@ -1,17 +1,25 @@
import contextlib
from typing import Any, Callable, Dict, Iterator, List, Tuple, Union, Set, Optional
import warnings
from collections import defaultdict
from typing import Any, Callable, Dict, Iterator, List, Set, Tuple, Union
import torch
from torch import Tensor
from torch.nn.utils._named_member_accessor import NamedMemberAccessor
__all__ = ["functional_call"]
# We avoid typing module here because module attributes are declared as Union[Parameter, Tensor] by default
# and using other types causes mypy errors
# TODO: remove this unreferenced function when `torch.nn.utils._stateless` is removed
def _change_class(module, params_and_buffers) -> None:
warnings.warn(
"The function `torch.nn.utils.stateless._change_class` is private "
"and it is deprecated now. It may be removed in a future release.",
DeprecationWarning,
)
cls = module.__class__
attr_to_path : Dict[str, str] = module._attr_to_path
attr_to_path: Dict[str, str] = module._attr_to_path
def _getattribute(self, name: str) -> Any:
if name in attr_to_path:
@ -37,144 +45,169 @@ def _change_class(module, params_and_buffers) -> None:
module._orig_class = cls
def _create_tied_weights_map(module: 'torch.nn.Module', params_and_buffers: Dict[str, Tensor]) -> Dict[str, str]:
def _untie_named_tensors_map(
module: "torch.nn.Module",
parameters_and_buffers: Dict[str, Tensor],
) -> Dict[str, Tensor]:
"""
_create_tied_weights_map(module: Module, params_and_buffers: Dict[str, Tensor]) -> Dict[str, str]
Unties all tied tensors in the module to parameters_and_buffers.
Creates a weight map of {tied_name: name_given_by_user} for all weights where one of their tied weights is passed
This function returns a new untied_parameters_and_buffers dictionary and leaves the original
parameters_and_buffers dictionary unchanged. It adds new (missing) keys for tied tensors
in the module to untied_parameters_and_buffers. The value of each new key is the user-given value
in the original parameters_and_buffers dictionary.
ex: Foo() has self.foo and self.tied_foo, which are tied. If a user passed {'foo': ...} as the reparamaterization,
this would return {'tied_foo': 'foo'}. Similarly if a user passed {'tied_foo': ...}, this returns
{'tied_foo': 'foo'}.
If more than one user-given value is passed for the same tied tensor, it raises an error.
ex: If there aren't any tied weights and the user passed values for every parameter and buffer, this will return a
map where every name maps to an empty set: {'l1.weight': set(), 'l1.bias': set(), ...}
For example, if the module has two tied weights self.foo and self.tied_foo and the user passes
{'foo': foo_value, ...}, this will return {'foo': foo_value, 'tied_foo': foo_value, ...}. If the
user passes {'foo': foo_value, 'tied_foo': tied_foo_value, ...}, it will raise an error. If the
user passes {'foo': foo_value, 'tied_foo': foo_value, ...}, it will not raise an error.
ex: The map only contains values that a user is reparamaterizing. For example, if module = nn.Linear(...) and the
user only passed a new value for 'bias', this looks returns: {'bias': set()}
Args:
module (torch.nn.Module): the module whose tied tensors should be resolved.
parameters_and_buffers (Dict[str, Tensor]): a map of {name: tensor} for reparametrizing the module.
This is useful because we will start by reparamaterizing all the keys of params_and_buffers, then all the key from
this returned dictionary.
Returns:
A new untied version of the parameters_and_buffers dictionary.
Raises:
ValueError: if more than one user-given value is passed for the same tied tensor.
"""
# A map of {name: tensor} for all tensors (including tied ones) in the module.
all_named_tensors: Dict[str, Tensor] = {}
all_named_tensors.update(module.named_parameters(remove_duplicate=False))
all_named_tensors.update(module.named_buffers(remove_duplicate=False))
# The basic algorithm looks like:
# - index all weights by their original tensor value to find tied weights
# - when we encounter a weight not used by the user, we save it in a set (second element in the tuple)
# - when we run into a weight used by the user, we save that separate from the set as the first element in the tuple
# - ending map looks like {tensor: (name_given_by_user, set(all_tied_names)}
# - then loop through the values of this map (name_given_by_user and set(all_tied_names))
# - for each element of all_tied_names, add {tied_name: name_given_by_user} to a new map
# A map of {tensor: set(all_tied_names)} for all tensor names in the module.
tensor_to_tied_names_map: Dict[Tensor, Set[str]] = defaultdict(set)
for name, tensor in all_named_tensors.items():
tensor_to_tied_names_map[tensor].add(name)
names = params_and_buffers.keys()
weight_to_name_and_tied_names: Dict[torch.Tensor, Tuple[Optional[str], Set[str]]] = {}
# A map of {tied_name: set(all_tied_names)} for all tensor names in the module.
# If a name is not tied, it will not be in this map.
tied_names_map: Dict[str, Set[str]] = {}
for tied_names in tensor_to_tied_names_map.values():
if len(tied_names) > 1:
for tied_name in tied_names:
tied_names_map[tied_name] = tied_names
# create a map keyed by tensor value so that tied weights get mapped to the same key. The value is the interesting
# part at the end it's (used_name, (tied_names)).
# For example, in the first example where there's tied weights self.foo and self.tied_foo and the user passes a
# value for self.foo, this will return {torch.Tensor(...): ('foo', set('tied_foo'))}
def add_to_name_map(n: str, t: torch.Tensor):
# if the tensor hasn't been seen before, add it to the map
if t not in weight_to_name_and_tied_names:
weight_to_name_and_tied_names[t] = (n, set()) if n in names else (None, {n})
return
# Make sure the user didn't pass multiple values for the same tied tensor.
given_names = set(parameters_and_buffers.keys())
given_names_for_tied_tensors = given_names.intersection(tied_names_map.keys())
for given_name in given_names_for_tied_tensors:
tied_names = tied_names_map[given_name]
if (
# Detect if there are multiple keys present for the same tied tensor.
len(tied_names.intersection(given_names_for_tied_tensors)) > 1
# Only raise an error if the user passed multiple values for the same tied tensor.
# If all given values are the same, don't raise.
and len({parameters_and_buffers[tied_name] for tied_name in tied_names})
!= 1
):
raise ValueError(
f"functional_call got multiple values for keys {sorted(tied_names)}, "
f"which are tied. Consider using tie_weights=False"
)
# if the name is not used by the user, we add it to the tied set
if n not in names:
weight_to_name_and_tied_names[t][1].add(n)
return
# check that the user didn't pass two different tensors for the same tied weight
first_seen_name = weight_to_name_and_tied_names[t][0]
# if they didn't pass multiple names for tied weights or used the same tensor, we set the used name
if first_seen_name is None or params_and_buffers[n] is params_and_buffers[first_seen_name]:
weight_to_name_and_tied_names[t] = (n, weight_to_name_and_tied_names[t][1])
return
raise ValueError(f"functional_call got values for both {n} and {first_seen_name}, which are tied. " +
"Consider using tie_weights=False")
tensor: Tensor
for name, tensor in module.named_parameters(remove_duplicate=False):
add_to_name_map(name, tensor)
for name, tensor in module.named_buffers(remove_duplicate=False):
add_to_name_map(name, tensor)
# make {tied_name: name_given_by_user} from pairs of (name_given_by_user, set(all_tied_names))
tied_weights_to_given_name = {}
for name_given_by_user, tied_names in weight_to_name_and_tied_names.values():
if name_given_by_user is None: # no mapping was passed for this tensor, use original tensor
continue
for tied_name in tied_names:
tied_weights_to_given_name[tied_name] = name_given_by_user
return tied_weights_to_given_name
def _create_swap_params(params_and_buffers):
def _swap_parameters(module, tensor_name: str, full_path: str, tensor: Optional[Tensor]) -> None:
# Changes the module class to get a new __getattr__ dunder method
# that looks for the reparametrized tensor
if hasattr(module, "_attr_to_path"):
module._attr_to_path[tensor_name] = full_path
else:
module._attr_to_path = {}
module._attr_to_path[tensor_name] = full_path
_change_class(module, params_and_buffers)
return _swap_parameters
def _remove_swap(module, name: str, full_path: str) -> None:
if hasattr(module, "_orig_class"):
module.__class__ = module._orig_class
delattr(module, "_orig_class")
delattr(module, "_attr_to_path")
# Untie the given named tensor map
# Copy first so that the original dict is not modified
untied_parameters_and_buffers = parameters_and_buffers.copy()
for given_name in given_names_for_tied_tensors:
for tied_name in tied_names_map[given_name]:
untied_parameters_and_buffers[tied_name] = parameters_and_buffers[
given_name
]
return untied_parameters_and_buffers
@contextlib.contextmanager
def _reparametrize_module(
module: 'torch.nn.Module',
module: "torch.nn.Module",
parameters_and_buffers: Dict[str, Tensor],
tie_weights: bool = False,
*,
strict: bool = False,
) -> Iterator[None]:
tied_weights_map = _create_tied_weights_map(module, parameters_and_buffers) if tie_weights else {}
for name, tensor in parameters_and_buffers.items():
_apply_func_submodules(
_create_swap_params(parameters_and_buffers),
module, name.split("."), name, (tensor,))
for tied_name, user_given_name in tied_weights_map.items():
_apply_func_submodules(
_create_swap_params(parameters_and_buffers),
module, tied_name.split("."), user_given_name, (None,))
if tie_weights:
untied_parameters_and_buffers = _untie_named_tensors_map(
module, parameters_and_buffers
)
else:
untied_parameters_and_buffers = parameters_and_buffers
accessor = NamedMemberAccessor(module)
if strict:
missing_keys, unexpected_keys = accessor.check_keys(
untied_parameters_and_buffers
)
error_msgs = []
if len(unexpected_keys) > 0:
error_msgs.append(
"Unexpected key(s): {}.".format(", ".join(map(repr, unexpected_keys)))
)
if len(missing_keys) > 0:
error_msgs.append(
"Missing key(s): {}.".format(", ".join(map(repr, missing_keys)))
)
if len(error_msgs) > 0:
raise RuntimeError(
"Error(s) in reparametrizing for {}:\n\t{}".format(
module._get_name(), "\n\t".join(error_msgs)
)
)
orig_parameters_and_buffers: Dict[str, Tensor] = {}
try:
orig_parameters_and_buffers, _ = accessor.swap_tensors_dict(
untied_parameters_and_buffers, allow_missing=True
)
yield
finally:
for name in parameters_and_buffers:
_apply_func_submodules(
_remove_swap,
module, name.split("."), name, ())
new_parameters_and_buffers, _ = accessor.swap_tensors_dict(
orig_parameters_and_buffers, allow_missing=True
)
# Sometimes the module is not completely stateless and has some in-place modifications on
# the _parameters and _buffers dictionaries.
# Write the changed parameters and buffers back to the original dict.
parameters_and_buffers.update(
{
k: new_parameters_and_buffers[k]
for k in parameters_and_buffers
if k in new_parameters_and_buffers
}
)
# TODO: remove this unreferenced function when `torch.nn.utils._stateless` is removed
def _apply_func_submodules(
func: Callable[..., None],
module: 'torch.nn.Module',
module: "torch.nn.Module",
path: List[str],
full_path: str,
args: Tuple,
):
warnings.warn(
"The function `torch.nn.utils.stateless._apply_func_submodules` is private "
"and it is deprecated now. It may be removed in a future release.",
DeprecationWarning,
)
if len(path) == 1:
func(module, path[0], full_path, *args)
else:
_apply_func_submodules(func, getattr(module, path[0]), path[1:], full_path, args)
_apply_func_submodules(
func, getattr(module, path[0]), path[1:], full_path, args
)
def functional_call(
module: 'torch.nn.Module',
module: "torch.nn.Module",
parameters_and_buffers: Dict[str, Tensor],
args: Union[Any, Tuple],
kwargs: Dict[str, Any] = None,
*,
tie_weights: bool = True,
strict: bool = False,
):
r"""Performs a functional call on the module by replacing the module parameters
and buffers with the provided ones.
@ -229,6 +262,9 @@ def functional_call(
tied in the reparametrized version. Therefore, if True and different values are passed for the tied
parameters and buffers, it will error. If False, it will not respect the originally tied parameters and
buffers unless the values passed for both weights are the same. Default: True.
strict (bool, optional): If True, then the parameters and buffers passed in must match the parameters and
buffers in the original module. Therefore, if True and there are any missing or unexpected keys, it will
error. Default: False.
Returns:
Any: the result of calling ``module``.
@ -236,35 +272,47 @@ def functional_call(
warnings.warn(
"This API is deprecated as of PyTorch 2.0 and will be removed in a future "
"version of PyTorch. Please use torch.func.functional_call instead "
"which is a drop-in replacement for this API.")
"which is a drop-in replacement for this API."
)
return _functional_call(
module,
parameters_and_buffers,
args,
kwargs,
tie_weights=tie_weights,
strict=strict,
)
return _functional_call(module, parameters_and_buffers, args, kwargs,
tie_weights=tie_weights)
def _functional_call(
module: 'torch.nn.Module',
module: "torch.nn.Module",
parameters_and_buffers: Dict[str, Tensor],
args: Union[Any, Tuple],
kwargs: Dict[str, Any] = None,
*,
tie_weights: bool = True,
strict: bool = False,
):
# TODO allow kwargs such as unsafe and others for parametrization
if (
torch.jit.is_tracing()
or torch.jit.is_scripting()
or isinstance(module, (
torch.jit.is_tracing()
or torch.jit.is_scripting()
or isinstance(
module,
(
torch.jit.RecursiveScriptModule,
torch.jit.ScriptModule,
torch.jit.ScriptFunction)
)
torch.jit.ScriptFunction,
),
)
):
raise RuntimeError("The stateless API can't be used with Jitted modules")
if kwargs is None:
kwargs = {}
with _reparametrize_module(module, parameters_and_buffers, tie_weights):
if isinstance(args, tuple):
out = module(*args, **kwargs)
else:
out = module(args, **kwargs)
return out
if not isinstance(args, tuple):
args = (args,)
with _reparametrize_module(
module, parameters_and_buffers, tie_weights=tie_weights, strict=strict
):
return module(*args, **kwargs)
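
Finally, a hedged sketch of the tie_weights path implemented by `_untie_named_tensors_map` above; `TiedModule` is an illustrative stand-in for the `MockTiedModule` used in the tests:

import torch
import torch.nn as nn

class TiedModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.l1 = nn.Linear(1, 1)
        self.tied_bias = self.l1.bias  # ties the parameter under a second name

    def forward(self, x):
        return self.l1(x) + self.tied_bias

module = TiedModule()
params = {'l1.weight': torch.tensor([[2.0]]), 'l1.bias': torch.tensor([5.0])}
x = torch.ones(1, 1)

# With tie_weights=True (the default), the value given for 'l1.bias' is also
# used for the tied 'tied_bias', so the output is x * 2 + 5 + 5 = 12.
out = torch.func.functional_call(module, params, (x,))
assert torch.allclose(out, torch.tensor([[12.0]]))

# Conflicting values for tied names raise, as tested above.
params['tied_bias'] = torch.tensor([1.0])
try:
    torch.func.functional_call(module, params, (x,))
except ValueError as e:
    print(e)  # functional_call got multiple values for keys ['l1.bias', 'tied_bias'] ...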