Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
[functorch][nn] Refactor NN stateless APIs by swapping module tensors (#92536)
- Fixes #92295
- Resolves #86708
- Resolves #92153
- Closes #92401
- Closes #92218
- Requires #91579

Refactor NN stateless APIs by swapping module tensors.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/92536
Approved by: https://github.com/jbschlosser
Committed by: PyTorch MergeBot
Parent: 3fd46a2f9c
Commit: b8de1cf007
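For illustration — a minimal sketch (not part of the commit) of the torch.func.functional_call behavior this PR refactors; tensor values are made up:

    import torch
    from torch.func import functional_call

    model = torch.nn.Linear(1, 1)
    params = {'weight': torch.tensor([[2.0]]), 'bias': torch.tensor([0.5])}
    x = torch.ones(1, 1)

    # The supplied tensors are swapped into the module for the duration of the
    # call; the module's own parameters are restored afterwards, even on error.
    out = functional_call(model, params, x)

    # With the new strict flag, the dict must cover the module's parameters and
    # buffers exactly (no missing or unexpected keys).
    out = functional_call(model, params, x, strict=True)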
@@ -835,9 +835,10 @@ include_patterns = [
    'torch/_*.py',
    'torch/testing/_internal/opinfo/**/*.py',
    'torchgen/**/*.py',
    'functorch/functorch/_src/aot_autograd.py',
    'functorch/functorch/_src/compilers.py',
    'torch/_functorch/make_functional.py',
    'torch/_functorch/functional_call.py',
    'torch/nn/utils/_named_member_accessor.py',
    'torch/nn/utils/stateless.py',
    'torch/testing/*.py',
    'torch/distributed/fsdp/**/*.py',
    'test/distributed/fsdp/**/*.py',
@@ -1,12 +1,13 @@
 # Owner(s): ["module: nn"]

-import unittest
-import sys
 import contextlib
 import os
+import re
 import subprocess
+import sys
+import unittest

 import torch

 import torch.nn.utils.stateless as stateless
 from torch.testing._internal.common_cuda import TEST_MULTIGPU
 from torch.testing._internal.common_utils import run_tests, TestCase, parametrize, instantiate_parametrized_tests, \
@@ -18,10 +19,12 @@ class MockModule(torch.nn.Module):
        super().__init__()
        self.l1 = torch.nn.Linear(1, 1)
        self.register_buffer('buffer', torch.ones(1))
        self.foo = 0.0

    def forward(self, x):
        return self.l1(x) + self.buffer


class MockTiedModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
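(The MockTiedModule body is truncated by this view. A hypothetical reconstruction, consistent with the tied-weight tests below that expect out == x * weight + bias + bias + buffer + buffer — the exact upstream definition is not shown here:)

    class MockTiedModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.l1 = torch.nn.Linear(1, 1)
            self.tied_bias = self.l1.bias                      # tied parameter
            self.register_buffer('buffer', torch.ones(1))
            self.register_buffer('tied_buffer', self.buffer)   # tied buffer

        def forward(self, x):
            return self.l1(x) + self.tied_bias + self.buffer + self.tied_buffer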
@@ -65,6 +68,29 @@ class TestStatelessFunctionalAPI(TestCase):
        self.assertEqual(cur_weight, prev_weight)
        self.assertEqual(cur_buffer, prev_buffer)

    @contextlib.contextmanager
    def _ensure_module_unchanged(self, module, message):
        orig_parameters, orig_buffers = tuple(module.parameters()), tuple(module.buffers())
        orig_tensors = orig_parameters + orig_buffers
        orig_tensors_values = tuple(t.clone() for t in orig_tensors)
        try:
            yield module
        finally:
            parameters, buffers = tuple(module.parameters()), tuple(module.buffers())
            self.assertTrue(
                len(parameters) == len(orig_parameters)
                and len(buffers) == len(orig_buffers)
                and all(
                    t1 is t2 and torch.allclose(t1, t3)
                    for t1, t2, t3 in zip(
                        orig_tensors,
                        parameters + buffers,
                        orig_tensors_values,
                    )
                ),
                message,
            )

    @parametrize("functional_call", [
        subtest(torch.func.functional_call, "torch_func"),
        subtest(stateless.functional_call, "stateless")
@@ -201,7 +227,7 @@ class TestStatelessFunctionalAPI(TestCase):
         subtest(torch.func.functional_call, "torch_func"),
         subtest(stateless.functional_call, "stateless")
     ])
-    def test_reparamertize_module_fail_reset_to_original(self, functional_call):
+    def test_reparametrize_module_fail_reset_to_original(self, functional_call):
         module = MockModule()
         torch.nn.utils.parametrizations.spectral_norm(module.l1)
         self.assertTrue('l1.parametrizations.weight.original' in dict(module.named_parameters()))
@@ -220,6 +246,161 @@ class TestStatelessFunctionalAPI(TestCase):
        self.assertTrue('l1.parametrizations.weight.original' in dict(module.named_parameters()))
        self.assertEqual(orig_sn_weight, module.l1.weight)

    @parametrize("functional_call", [
        subtest(torch.func.functional_call, "torch_func"),
        subtest(stateless.functional_call, "stateless")
    ])
    def test_reparametrize_some_weights(self, functional_call):
        module = MockModule()
        weight = torch.tensor([[2.0]])
        bias = torch.tensor([5.0])
        buffer = torch.tensor([3.0])
        extra = torch.tensor([1.0])

        parameters = {'l1.weight': weight}
        x = torch.randn(1, 1)
        out = functional_call(module, parameters, x)
        self.assertEqual(out, x * weight + module.l1.bias + module.buffer)

        parameters = {'l1.weight': weight,
                      'extra': extra}
        x = torch.randn(1, 1)
        out = functional_call(module, parameters, x)
        self.assertEqual(out, x * weight + module.l1.bias + module.buffer)

    @parametrize("functional_call", [
        subtest(torch.func.functional_call, "torch_func"),
        subtest(stateless.functional_call, "stateless")
    ])
    def test_reparametrize_strict(self, functional_call):
        module = MockModule()
        weight = torch.tensor([[2.0]])
        bias = torch.tensor([5.0])
        buffer = torch.tensor([3.0])
        extra = torch.tensor([1.0])

        # All weights, no error
        parameters = {'l1.weight': weight,
                      'l1.bias': bias,
                      'buffer': buffer}
        x = torch.randn(1, 1)
        with self._ensure_module_unchanged(
            module,
            'the module should not have been modified by a successful call',
        ):
            out = functional_call(module, parameters, x, strict=True)
            self.assertEqual(out, x * weight + bias + buffer)

        # Some weights
        parameters = {'l1.weight': weight}
        x = torch.randn(1, 1)
        with self._ensure_module_unchanged(
            module,
            'the module should not have been modified by a failed call',
        ):
            with self.assertRaisesRegex(
                RuntimeError,
                re.escape("Missing key(s): 'buffer', 'l1.bias'."),
            ):
                out = functional_call(module, parameters, x, strict=True)

        # Extra keys
        parameters = {'l1.weight': weight,
                      'l1.bias': bias,
                      'buffer': buffer,
                      'extra': extra}
        x = torch.randn(1, 1)
        with self._ensure_module_unchanged(
            module,
            'the module should not have been modified by a failed call',
        ):
            with self.assertRaisesRegex(
                RuntimeError,
                re.escape("Unexpected key(s): 'extra'."),
            ):
                out = functional_call(module, parameters, x, strict=True)

        # Some weights with extra keys
        parameters = {'l1.weight': weight,
                      'extra': extra}
        x = torch.randn(1, 1)
        with self._ensure_module_unchanged(
            module,
            'the module should not have been modified by a failed call',
        ):
            with self.assertRaisesRegex(
                RuntimeError,
                re.escape("Unexpected key(s): 'extra'.") + r'\s+' + re.escape("Missing key(s): 'buffer', 'l1.bias'."),
            ):
                out = functional_call(module, parameters, x, strict=True)

    @parametrize("functional_call", [
        subtest(torch.func.functional_call, "torch_func"),
        subtest(stateless.functional_call, "stateless")
    ])
    def test_reparametrize_special(self, functional_call):
        class NonTensor:
            def __repr__(self):
                return f'<{self.__class__.__name__}>'

        module = MockModule()
        weight = torch.tensor([[2.0]])
        bias = torch.tensor([5.0])
        buffer = torch.tensor([3.0])
        non_tensor = NonTensor()

        # Set to None
        parameters = {'l1.weight': weight,
                      'l1.bias': None,
                      'buffer': buffer}
        x = torch.randn(1, 1)
        with self._ensure_module_unchanged(
            module,
            'the module should not have been modified by a successful call',
        ):
            out = functional_call(module, parameters, x)
            self.assertEqual(out, x * weight + buffer)

        # Set non-tensor
        parameters = {'l1.weight': non_tensor}
        x = torch.randn(1, 1)
        with self._ensure_module_unchanged(
            module,
            'the module should not have been modified by a failed call',
        ):
            with self.assertRaisesRegex(
                TypeError,
                re.escape("<NonTensor> is not an instance of torch.Tensor"),
            ):
                out = functional_call(module, parameters, x)

        # Set non-tensor attribute
        parameters = {'l1.weight': weight, 'foo': torch.tensor([1.0])}
        x = torch.randn(1, 1)
        with self._ensure_module_unchanged(
            module,
            'the module should not have been modified by a failed call',
        ):
            with self.assertRaisesRegex(
                TypeError,
                re.escape("attribute `foo`: 0.0 is not an instance of torch.Tensor"),
            ):
                out = functional_call(module, parameters, x)

        # Set non-existent submodule
        parameters = {'l1.weight': weight,
                      'l2.bias': bias}
        x = torch.randn(1, 1)
        with self._ensure_module_unchanged(
            module,
            'the module should not have been modified by a failed call',
        ):
            with self.assertRaisesRegex(
                AttributeError,
                re.escape("MockModule has no attribute `l2`"),
            ):
                out = functional_call(module, parameters, x)

    @parametrize("functional_call", [
        subtest(torch.func.functional_call, "torch_func"),
        subtest(stateless.functional_call, "stateless")
@@ -233,11 +414,12 @@ class TestStatelessFunctionalAPI(TestCase):
         subtest(torch.func.functional_call, "torch_func"),
         subtest(stateless.functional_call, "stateless")
     ])
-    def test_reparamertize_tie_weights(self, functional_call):
+    def test_reparametrize_tie_weights(self, functional_call):
         module = MockTiedModule()
-        weight = torch.tensor([[2.0]],)
+        weight = torch.tensor([[2.0]])
         bias = torch.tensor([5.0])
         buffer = torch.tensor([3.0])
+        extra = torch.tensor([1.0])

         parameters = {'l1.weight': weight,
                       'l1.bias': bias,
@@ -246,14 +428,21 @@ class TestStatelessFunctionalAPI(TestCase):
         out = functional_call(module, parameters, x, tie_weights=True)
         self.assertEqual(out, x * weight + bias + bias + buffer + buffer)

+        parameters = {'l1.weight': weight,
+                      'l1.bias': bias,
+                      'buffer': buffer,
+                      'extra': extra}
+        x = torch.randn(1, 1)
+        out = functional_call(module, parameters, x, tie_weights=True)
+        self.assertEqual(out, x * weight + bias + bias + buffer + buffer)
+
     @parametrize("functional_call", [
         subtest(torch.func.functional_call, "torch_func"),
         subtest(stateless.functional_call, "stateless")
     ])
-    def test_reparamertize_tie_some_weights(self, functional_call):
+    def test_reparametrize_tie_some_weights(self, functional_call):
         module = MockTiedModule()
-        weight = torch.tensor([[2.0]],)
+        weight = torch.tensor([[2.0]])
         buffer = torch.tensor([3.0])

         parameters = {'l1.weight': weight,
@@ -268,7 +457,7 @@ class TestStatelessFunctionalAPI(TestCase):
     ])
     def test_tied_weights_errors(self, functional_call):
         module = MockTiedModule()
-        weight = torch.tensor([[1.0]],)
+        weight = torch.tensor([[1.0]])
         bias = torch.tensor([0.0])
         buffer = torch.tensor([0.0])

@@ -285,19 +474,24 @@ class TestStatelessFunctionalAPI(TestCase):
         del parameters['tied_bias']
         del parameters['tied_buffer']

-        with self.assertRaisesRegex(ValueError, "functional_call got values for both (l1.bias|tied_bias)"):
+        with self.assertRaisesRegex(
+            ValueError,
+            re.escape("functional_call got multiple values for keys ['l1.bias', 'tied_bias']"),
+        ):
             parameters['tied_bias'] = torch.tensor([5.0])
             functional_call(module, parameters, x, tie_weights=True)
         del parameters['tied_bias']

-        with self.assertRaisesRegex(ValueError, "functional_call got values for both (buffer|tied_buffer)"):
+        with self.assertRaisesRegex(
+            ValueError,
+            re.escape("functional_call got multiple values for keys ['buffer', 'tied_buffer']"),
+        ):
             parameters['tied_buffer'] = torch.tensor([5.0])
             functional_call(module, parameters, x, tie_weights=True)


     def test_tied_weights_no_error_without_flag(self):
         module = MockTiedModule()
-        weight = torch.tensor([[1.0]],)
+        weight = torch.tensor([[1.0]])
         bias = torch.tensor([0.0])
         buffer = torch.tensor([0.0])

@@ -312,6 +506,105 @@ class TestStatelessFunctionalAPI(TestCase):
        parameters['tied_buffer'] = torch.tensor([5.0])
        self.assertNotWarn(lambda: stateless._functional_call(module, parameters, x, tie_weights=False))

    @parametrize("functional_call", [
        subtest(torch.func.functional_call, "torch_func"),
        subtest(stateless.functional_call, "stateless")
    ])
    def test_reparametrize_tie_weights_strict(self, functional_call):
        module = MockTiedModule()
        weight = torch.tensor([[2.0]])
        bias = torch.tensor([5.0])
        buffer = torch.tensor([3.0])
        extra = torch.tensor([1.0])

        # Tie weights no error
        parameters = {'l1.weight': weight,
                      'l1.bias': bias,
                      'buffer': buffer}
        x = torch.randn(1, 1)
        with self._ensure_module_unchanged(
            module,
            'the module should not have been modified by a successful call',
        ):
            out = functional_call(module, parameters, x, tie_weights=True, strict=True)
            self.assertEqual(out, x * weight + bias + bias + buffer + buffer)

        # Tie weights without flag
        parameters = {'l1.weight': weight,
                      'l1.bias': bias,
                      'buffer': buffer}
        x = torch.randn(1, 1)
        with self._ensure_module_unchanged(
            module,
            'the module should not have been modified by a failed call',
        ):
            with self.assertRaisesRegex(
                RuntimeError,
                re.escape("Missing key(s): 'tied_bias', 'tied_buffer'."),
            ):
                out = functional_call(module, parameters, x, tie_weights=False, strict=True)

        # Tie some weights
        parameters = {'l1.weight': weight,
                      'buffer': buffer}
        x = torch.randn(1, 1)
        with self._ensure_module_unchanged(
            module,
            'the module should not have been modified by a failed call',
        ):
            with self.assertRaisesRegex(
                RuntimeError,
                re.escape("Missing key(s): 'l1.bias', 'tied_bias'."),
            ):
                out = stateless.functional_call(module, parameters, x, tie_weights=True, strict=True)

        # Tie weights with extra keys
        parameters = {'l1.weight': weight,
                      'l1.bias': bias,
                      'buffer': buffer,
                      'extra': extra}
        x = torch.randn(1, 1)
        with self._ensure_module_unchanged(
            module,
            'the module should not have been modified by a failed call',
        ):
            with self.assertRaisesRegex(
                RuntimeError,
                re.escape("Unexpected key(s): 'extra'."),
            ):
                out = stateless.functional_call(module, parameters, x, tie_weights=True, strict=True)

        # Tie weights with extra keys and without flag
        parameters = {'l1.weight': weight,
                      'l1.bias': bias,
                      'buffer': buffer,
                      'extra': extra}
        x = torch.randn(1, 1)
        with self._ensure_module_unchanged(
            module,
            'the module should not have been modified by a failed call',
        ):
            with self.assertRaisesRegex(
                RuntimeError,
                re.escape("Unexpected key(s): 'extra'.") + r'\s+' + re.escape("Missing key(s): 'tied_bias', 'tied_buffer'."),
            ):
                out = stateless.functional_call(module, parameters, x, tie_weights=False, strict=True)

        # Tie some weights with extra keys
        parameters = {'l1.weight': weight,
                      'buffer': buffer,
                      'extra': extra}
        x = torch.randn(1, 1)
        with self._ensure_module_unchanged(
            module,
            'the module should not have been modified by a failed call',
        ):
            with self.assertRaisesRegex(
                RuntimeError,
                re.escape("Unexpected key(s): 'extra'.") + r'\s+' + re.escape("Missing key(s): 'l1.bias', 'tied_bias'."),
            ):
                out = stateless.functional_call(module, parameters, x, tie_weights=True, strict=True)

    @parametrize("functional_call", [
        subtest(torch.func.functional_call, "torch_func"),
        subtest(stateless.functional_call, "stateless")
@@ -320,17 +613,89 @@ class TestStatelessFunctionalAPI(TestCase):
         class Foo(torch.nn.Module):
             def __init__(self):
                 super().__init__()
-                self.register_buffer('foo', torch.zeros(()))
+                self.register_buffer('foo', torch.tensor([0.0]))

             def forward(self, x):
                 self.foo = self.foo + 1
                 return x + self.foo

-        a = {'foo': torch.zeros(())}
+        foo = torch.tensor([2.0])
+        x = torch.randn(1)
+        a = {'foo': foo}
         mod = Foo()
-        functional_call(mod, a, torch.ones(()))
-        self.assertEqual(mod.foo, torch.zeros(()))
-        self.assertEqual(a['foo'], torch.ones(()))
+        functional_call(mod, a, x)
+        self.assertEqual(mod.foo, torch.tensor([0.0]))
+        self.assertEqual(a['foo'], torch.tensor([3.0]))
+        self.assertEqual(foo, torch.tensor([2.0]))
+        self.assertTrue(a['foo'] is not foo)

+    @parametrize("functional_call", [
+        subtest(torch.func.functional_call, "torch_func"),
+        subtest(stateless.functional_call, "stateless")
+    ])
+    def test_in_place_operator(self, functional_call):
+        class Foo(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.register_buffer('foo', torch.tensor([0.0]))
+
+            def forward(self, x):
+                self.foo.add_(1)
+                return x + self.foo
+
+        foo = torch.tensor([2.0])
+        x = torch.randn(1)
+        a = {'foo': foo}
+        mod = Foo()
+        functional_call(mod, a, x)
+        self.assertEqual(mod.foo, torch.tensor([0.0]))
+        self.assertEqual(a['foo'], torch.tensor([3.0]))
+        self.assertEqual(foo, torch.tensor([3.0]))
+        self.assertTrue(a['foo'] is foo)
+
+    @parametrize("functional_call", [
+        subtest(torch.func.functional_call, "torch_func"),
+        subtest(stateless.functional_call, "stateless")
+    ])
+    def test_setattr_strict(self, functional_call):
+        class Bar(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                assert not hasattr(self, 'extra')
+
+            def forward(self, x):
+                return x + self.extra
+
+        a = {'extra': torch.zeros(())}
+        mod = Bar()
+        self.assertTrue(not hasattr(mod, 'extra'))
+        out = functional_call(mod, a, torch.ones(()))
+        self.assertEqual(out, torch.ones(()))
+        self.assertTrue(not hasattr(mod, 'extra'))
+
+        a = {'extra': torch.zeros(())}
+        with self.assertRaisesRegex(
+            RuntimeError,
+            re.escape("Unexpected key(s): 'extra'."),
+        ):
+            out = functional_call(mod, a, torch.ones(()), strict=True)
+        self.assertTrue(not hasattr(mod, 'extra'))
+
+        a = {}
+        with self.assertRaisesRegex(
+            AttributeError,
+            re.escape("'Bar' object has no attribute 'extra'"),
+        ):
+            out = functional_call(mod, a, torch.ones(()))
+        self.assertTrue(not hasattr(mod, 'extra'))
+
+        a = {}
+        with self.assertRaisesRegex(
+            AttributeError,
+            re.escape("'Bar' object has no attribute 'extra'"),
+        ):
+            out = functional_call(mod, a, torch.ones(()), strict=True)
+        self.assertTrue(not hasattr(mod, 'extra'))
+
     @parametrize("functional_call", [
         subtest(torch.func.functional_call, "torch_func"),
@@ -355,7 +720,6 @@ class TestStatelessFunctionalAPI(TestCase):
         res_1 = functional_call(mod, a, (), {'inp': inp, 'other_inp': other_inp})
         self.assertEqual(res, res_1)

-
     def test_functional_call_tuple_dicts(self):
         mod = MockModule()
         x = torch.rand((1, 1))
@@ -375,15 +739,121 @@ class TestStatelessFunctionalAPI(TestCase):
         res = torch.func.functional_call(mod, a, x)
         self.assertEqual(res, x + 1)

     def test_functional_call_multiple_dicts_error(self):
         mod = MockModule()
         x = torch.rand((1, 1))
         parameters = {'l1.weight': torch.zeros((1, 1)), 'l1.bias': torch.zeros((1, 1))}
         repeated_parameters = {'l1.weight': torch.ones((1, 1))}
-        with self.assertRaisesRegex(ValueError, "l1.weight appeared in multiple dictionaries"):
+        with self.assertRaisesRegex(
+            ValueError,
+            re.escape("['l1.weight'] appeared in multiple dictionaries"),
+        ):
             torch.func.functional_call(mod, (parameters, repeated_parameters), x)

+    @parametrize("functional_call", [
+        subtest(torch.func.functional_call, "torch_func"),
+        subtest(stateless.functional_call, "stateless")
+    ])
+    def test_functional_call_member_reference(self, functional_call):
+        class Module(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.l1 = torch.nn.Linear(1, 1)
+                self.register_buffer('buffer', torch.ones(1))
+
+            def forward(self, x):
+                parameters = tuple(self.parameters())
+                buffers = tuple(self.buffers())
+                return self.l1(x) + self.buffer, parameters, buffers
+
+        module = Module()
+        weight = torch.tensor([[2.0]])
+        bias = torch.tensor([5.0])
+        buffer = torch.tensor([3.0])
+        extra = torch.tensor([1.0])
+        extra_p = torch.nn.Parameter(extra)
+
+        # All weights
+        parameters = {'l1.weight': weight,
+                      'l1.bias': bias,
+                      'buffer': buffer}
+        x = torch.randn(1, 1)
+        out, parameters, buffers = functional_call(module, parameters, x)
+        self.assertEqual(out, x * weight + bias + buffer)
+        self.assertEqual(parameters, (weight, bias))
+        self.assertEqual(buffers, (buffer,))
+        self.assertTrue(all(t1 is t2 for t1, t2 in zip(parameters, (weight, bias))))
+        self.assertTrue(all(t1 is t2 for t1, t2 in zip(buffers, (buffer,))))
+
+        # Some weights
+        parameters = {'l1.weight': weight}
+        x = torch.randn(1, 1)
+        out, parameters, buffers = functional_call(module, parameters, x)
+        self.assertEqual(out, x * weight + module.l1.bias + module.buffer)
+        self.assertEqual(parameters, (weight, module.l1.bias))
+        self.assertEqual(buffers, (module.buffer,))
+        self.assertTrue(all(t1 is t2 for t1, t2 in zip(parameters, (weight, module.l1.bias))))
+        self.assertTrue(all(t1 is t2 for t1, t2 in zip(buffers, (module.buffer,))))
+
+        # All weights with extra keys
+        parameters = {'l1.weight': weight,
+                      'l1.bias': bias,
+                      'buffer': buffer,
+                      'l1.extra': extra}
+        x = torch.randn(1, 1)
+        out, parameters, buffers = functional_call(module, parameters, x)
+        self.assertEqual(out, x * weight + bias + buffer)
+        self.assertEqual(parameters, (weight, bias))
+        self.assertEqual(buffers, (buffer,))
+        self.assertTrue(all(t1 is t2 for t1, t2 in zip(parameters, (weight, bias))))
+        self.assertTrue(all(t1 is t2 for t1, t2 in zip(buffers, (buffer,))))
+
+        # All weights with extra keys with parameters
+        parameters = {'l1.weight': weight,
+                      'l1.bias': bias,
+                      'buffer': buffer,
+                      'l1.extra': extra_p}
+        x = torch.randn(1, 1)
+        out, parameters, buffers = functional_call(module, parameters, x)
+        self.assertEqual(out, x * weight + bias + buffer)
+        self.assertEqual(parameters, (weight, bias, extra_p))
+        self.assertEqual(buffers, (buffer,))
+        self.assertTrue(all(t1 is t2 for t1, t2 in zip(parameters, (weight, bias, extra_p))))
+        self.assertTrue(all(t1 is t2 for t1, t2 in zip(buffers, (buffer,))))
+
+        # Some weights with extra keys
+        parameters = {'l1.weight': weight,
+                      'l1.extra': extra}
+        x = torch.randn(1, 1)
+        out, parameters, buffers = functional_call(module, parameters, x)
+        self.assertEqual(out, x * weight + module.l1.bias + module.buffer)
+        self.assertEqual(parameters, (weight, module.l1.bias))
+        self.assertEqual(buffers, (module.buffer,))
+        self.assertTrue(all(t1 is t2 for t1, t2 in zip(parameters, (weight, module.l1.bias))))
+        self.assertTrue(all(t1 is t2 for t1, t2 in zip(buffers, (module.buffer,))))
+
+        # Some weights with extra keys with parameters
+        parameters = {'l1.weight': weight,
+                      'l1.extra': extra_p}
+        x = torch.randn(1, 1)
+        out, parameters, buffers = functional_call(module, parameters, x)
+        self.assertEqual(out, x * weight + module.l1.bias + module.buffer)
+        self.assertEqual(parameters, (weight, module.l1.bias, extra_p))
+        self.assertEqual(buffers, (module.buffer,))
+        self.assertTrue(all(t1 is t2 for t1, t2 in zip(parameters, (weight, module.l1.bias, extra_p))))
+        self.assertTrue(all(t1 is t2 for t1, t2 in zip(buffers, (module.buffer,))))
+
+        # Set None
+        parameters = {'l1.weight': weight,
+                      'l1.bias': None}
+        x = torch.randn(1, 1)
+        out, parameters, buffers = functional_call(module, parameters, x)
+        self.assertEqual(out, x * weight + module.buffer)
+        self.assertEqual(parameters, (weight,))
+        self.assertEqual(buffers, (module.buffer,))
+        self.assertTrue(all(t1 is t2 for t1, t2 in zip(parameters, (weight,))))
+        self.assertTrue(all(t1 is t2 for t1, t2 in zip(buffers, (module.buffer,))))


 class TestStatelessDeprecation(TestCase):
     def test_private_stateless_warns(self):
@@ -1,4 +1,5 @@
-from typing import Dict, Union, Any, Tuple, List
+from collections import Counter
+from typing import Any, Dict, List, Sequence, Tuple, Union

 import torch
 import torch.nn as nn
@@ -8,12 +9,13 @@ from torch._functorch.utils import exposed_in

 @exposed_in("torch.func")
 def functional_call(
-    module: 'torch.nn.Module',
-    parameter_and_buffer_dicts: Union[Dict[str, Tensor], Tuple[Dict[str, Tensor], ...]],
+    module: "torch.nn.Module",
+    parameter_and_buffer_dicts: Union[Dict[str, Tensor], Sequence[Dict[str, Tensor]]],
     args: Union[Any, Tuple],
     kwargs: Dict[str, Any] = None,
     *,
     tie_weights: bool = True,
+    strict: bool = False,
 ):
     r"""Performs a functional call on the module by replacing the module parameters
     and buffers with the provided ones.
@@ -100,7 +102,7 @@ def functional_call(

     Args:
         module (torch.nn.Module): the module to call
-        parameters_and_buffers (Dict[str,Tensor] or tuple of Dict[str, Tensor]): the parameters that will be used in
+        parameters_and_buffers (Dict[str, Tensor] or tuple of Dict[str, Tensor]): the parameters that will be used in
             the module call. If given a tuple of dictionaries, they must have distinct keys so that all dictionaries can
             be used together
         args (Any or tuple): arguments to be passed to the module call. If not a tuple, considered a single argument.
@@ -109,25 +111,49 @@ def functional_call(
             tied in the reparameterized version. Therefore, if True and different values are passed for the tied
             parameters and buffers, it will error. If False, it will not respect the originally tied parameters and
             buffers unless the values passed for both weights are the same. Default: True.
+        strict (bool, optional): If True, then the parameters and buffers passed in must match the parameters and
+            buffers in the original module. Therefore, if True and there are any missing or unexpected keys, it will
+            error. Default: False.

     Returns:
         Any: the result of calling ``module``.
     """
-    parameters_and_buffers = parameter_and_buffer_dicts if isinstance(parameter_and_buffer_dicts, dict) else {}
-    if isinstance(parameter_and_buffer_dicts, tuple):
-        key_list = [i for dct in parameter_and_buffer_dicts for i in dct.keys()]
-        key_set = set(key_list)
-        if len(key_set) != len(key_list):
-            repeated_key = list(filter(lambda key: key_list.count(key) > 1, key_set))[0]
-            raise ValueError(f"{repeated_key} appeared in multiple dictionaries; behavior of functional call is ambiguous")
+    if isinstance(parameter_and_buffer_dicts, dict):
+        parameters_and_buffers = parameter_and_buffer_dicts
+    elif isinstance(parameter_and_buffer_dicts, Sequence):
+        if not all(isinstance(d, dict) for d in parameter_and_buffer_dicts):
+            raise ValueError(
+                "Expected all elements of parameter_and_buffer_dicts to be dictionaries"
+            )
+        all_keys = [k for d in parameter_and_buffer_dicts for k in d.keys()]
+        repeated_keys = [key for key, n in Counter(all_keys).items() if n > 1]
+        if len(repeated_keys) > 0:
+            raise ValueError(
+                f"{repeated_keys} appeared in multiple dictionaries; behavior of functional call is ambiguous"
+            )
+        parameters_and_buffers = {
+            k: v for d in parameter_and_buffer_dicts for k, v in d.items()
+        }
+    else:
+        raise ValueError(
+            f"Expected parameter_and_buffer_dicts to be a dict, or a list/tuple of dicts, "
+            f"but got {type(parameter_and_buffer_dicts)}"
+        )

-    parameters_and_buffers = {k: v for d in parameter_and_buffer_dicts for k, v in d.items()}
-
-    return nn.utils.stateless._functional_call(module, parameters_and_buffers, args, kwargs, tie_weights=tie_weights)
+    return nn.utils.stateless._functional_call(
+        module,
+        parameters_and_buffers,
+        args,
+        kwargs,
+        tie_weights=tie_weights,
+        strict=strict,
+    )

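A small illustrative sketch (not part of the diff) of the dispatch above — a single dict or any sequence of dicts with disjoint keys is accepted; duplicate keys raise:

    import torch
    from torch.func import functional_call

    model = torch.nn.Linear(1, 1)
    weights = {'weight': torch.tensor([[2.0]])}
    biases = {'bias': torch.tensor([0.5])}
    x = torch.ones(1, 1)

    out = functional_call(model, (weights, biases), x)  # tuple of dicts

    # Passing the same key twice is ambiguous and raises ValueError, e.g.
    # "['weight'] appeared in multiple dictionaries; behavior of functional call is ambiguous":
    # functional_call(model, (weights, {'weight': torch.zeros(1, 1)}), x)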
 @exposed_in("torch.func")
-def stack_module_state(models: List[nn.Module]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
+def stack_module_state(
+    models: List[nn.Module],
+) -> Tuple[Dict[str, Any], Dict[str, Any]]:
     """stack_module_state(models) -> params, buffers

     Prepares a list of torch.nn.Modules for ensembling with :func:`vmap`.
@@ -183,29 +209,39 @@ def stack_module_state(models: List[nn.Module]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
     same mode (training vs eval).
     """
     if len(models) == 0:
-        raise RuntimeError('stack_module_state: Expected at least one model, got 0.')
+        raise RuntimeError("stack_module_state: Expected at least one model, got 0.")
     if not (all(m.training for m in models) or all(not m.training for m in models)):
-        raise RuntimeError('stack_module_state: Expected all models to '
-                           'have the same training/eval mode.')
+        raise RuntimeError(
+            "stack_module_state: Expected all models to have the same training/eval mode."
+        )
     model0_typ = type(models[0])
     if not all(type(m) == model0_typ for m in models):
-        raise RuntimeError('stack_module_state: Expected all models to '
-                           'be of the same class.')
-    all_params = [{k: v for k, v in model.named_parameters()} for model in models]
-    params = {k: construct_stacked_leaf(tuple(params[k] for params in all_params), k)
-              for k in all_params[0]}
-    all_buffers = [{k: v for k, v in model.named_buffers()} for model in models]
-    buffers = {k: construct_stacked_leaf(tuple(buffers[k] for buffers in all_buffers), k)
-               for k in all_buffers[0]}
+        raise RuntimeError(
+            "stack_module_state: Expected all models to be of the same class."
+        )
+    all_params = [dict(model.named_parameters()) for model in models]
+    params = {
+        k: construct_stacked_leaf(tuple(params[k] for params in all_params), k)
+        for k in all_params[0]
+    }
+    all_buffers = [dict(model.named_buffers()) for model in models]
+    buffers = {
+        k: construct_stacked_leaf(tuple(buffers[k] for buffers in all_buffers), k)
+        for k in all_buffers[0]
+    }

     return params, buffers


-def construct_stacked_leaf(tensors, name):
-    all_requires_grad = all([t.requires_grad for t in tensors])
-    none_requires_grad = all([not t.requires_grad for t in tensors])
+def construct_stacked_leaf(
+    tensors: Union[Tuple[Tensor, ...], List[Tensor]], name: str
+) -> Tensor:
+    all_requires_grad = all(t.requires_grad for t in tensors)
+    none_requires_grad = all(not t.requires_grad for t in tensors)
     if not all_requires_grad and not none_requires_grad:
         raise RuntimeError(
-            f'Expected {name} from each model to have the same .requires_grad')
+            f"Expected {name} from each model to have the same .requires_grad"
+        )
     result = torch.stack(tensors)
     if all_requires_grad:
         result = result.detach().requires_grad_()
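An illustrative ensembling sketch (not part of the diff), following the documented stack_module_state + vmap workflow; model shapes are made up:

    import torch
    from torch.func import functional_call, stack_module_state, vmap

    models = [torch.nn.Linear(3, 1) for _ in range(2)]
    params, buffers = stack_module_state(models)  # each leaf gains a leading stacked dim

    base = models[0]  # all models share the same class, so one serves as the template

    def run(p, b, x):
        return functional_call(base, (p, b), (x,))

    x = torch.randn(5, 3)
    out = vmap(run, in_dims=(0, 0, None))(params, buffers, x)  # shape (2, 5, 1)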
@@ -21,43 +21,13 @@ from typing import (
 import torch
 import torch.nn as nn
 from torch import Tensor
+from torch.nn.utils._named_member_accessor import NamedMemberAccessor

 # Utilities to make nn.Module "functional"
 # In particular the goal is to be able to provide a function that takes as input
 # the parameters and evaluate the nn.Module using fixed inputs.


-def _del_nested_attr(obj: nn.Module, names: List[str]) -> None:
-    """
-    Deletes the attribute specified by the given list of names.
-    For example, to delete the attribute obj.conv.weight,
-    use _del_nested_attr(obj, ['conv', 'weight'])
-    """
-    if len(names) == 1:
-        delattr(obj, names[0])
-    else:
-        _del_nested_attr(getattr(obj, names[0]), names[1:])
-
-
-def _set_nested_attr(obj: nn.Module, names: List[str], value: Tensor) -> None:
-    """
-    Set the attribute specified by the given list of names to value.
-    For example, to set the attribute obj.conv.weight,
-    use _del_nested_attr(obj, ['conv', 'weight'], value)
-    """
-    if len(names) == 1:
-        setattr(obj, names[0], value)
-    else:
-        _set_nested_attr(getattr(obj, names[0]), names[1:], value)
-
-
-def _get_nested_attr(obj: nn.Module, names: List[str]) -> Tensor:
-    if len(names) == 1:
-        return getattr(obj, names[0])
-    else:
-        return _get_nested_attr(getattr(obj, names[0]), names[1:])
-
-
 def raise_parameter_tying_error() -> NoReturn:
     raise RuntimeError(
         "make_functional(module): we don't yet support models that "
@@ -71,14 +41,14 @@ def raise_parameter_tying_error() -> NoReturn:
 def create_names_map(
     named_params: Union[Dict[str, Tensor], Iterable[Tuple[str, Tensor]]],
     tied_named_params: Union[Dict[str, Tensor], Iterable[Tuple[str, Tensor]]],
-) -> Dict[str, List[List[str]]]:
+) -> Dict[str, List[str]]:
     """
     named_params is a dictionary of tensors: {'A': A, 'B': B}
     tied_named_params is another dictionary of tensors {'A': A, 'B': B, 'B_tied': B}
     with potentially tied (or 'duplicated') tensors

     This function creates a mapping from the names in named_params to the
-    names in tied_named_params: {'A': [['A']], 'B': [['B'], ['B_tied']]}.
+    names in tied_named_params: {'A': ['A'], 'B': ['B', 'B_tied']}.
     """
     named_params = dict(named_params)
     tied_named_params = dict(tied_named_params)
@@ -87,12 +57,12 @@ def create_names_map(
     tied_tensors_dict_keys = set(tied_named_params.keys())
     assert tensors_dict_keys.issubset(tied_tensors_dict_keys)

-    tensor_to_mapping: Dict[Tensor, Tuple[str, List[List[str]]]] = {}
+    tensor_to_mapping: Dict[Tensor, Tuple[str, List[str]]] = {}
     for key, tensor in named_params.items():
         tensor_to_mapping[tensor] = (key, [])
     for key, tensor in tied_named_params.items():
         assert tensor in tensor_to_mapping
-        tensor_to_mapping[tensor][1].append(key.split("."))
+        tensor_to_mapping[tensor][1].append(key)
     return dict(tensor_to_mapping.values())

@@ -100,18 +70,19 @@ def _extract_members(
     mod: nn.Module,
     named_members: Callable[..., Iterable[Tuple[str, Tensor]]],
     subclass: Callable[[Tensor], Tensor],
-) -> Tuple[Tuple[Tensor, ...], Tuple[str, ...], Dict[str, List[List[str]]]]:
+) -> Tuple[Tuple[Tensor, ...], Tuple[str, ...], Dict[str, List[str]]]:
     all_named_members = tuple(named_members(remove_duplicate=False))
     unique_named_members = tuple(named_members(remove_duplicate=True))
     names_map = create_names_map(unique_named_members, all_named_members)

     # Remove all the members in the model
     memo = {}
+    accessor = NamedMemberAccessor(mod)
     for name, p in all_named_members:
         if p not in memo:
             memo[p] = subclass(torch.empty_like(p, device="meta"))
         replacement = memo[p]
-        _set_nested_attr(mod, name.split("."), replacement)
+        accessor.set_tensor(name, replacement)

     if len(unique_named_members) == 0:
         names, params = (), ()
@@ -122,7 +93,7 @@ def _extract_members(

 def extract_weights(
     mod: nn.Module,
-) -> Tuple[Tuple[Tensor, ...], Tuple[str, ...], Dict[str, List[List[str]]]]:
+) -> Tuple[Tuple[Tensor, ...], Tuple[str, ...], Dict[str, List[str]]]:
     """
     This function removes all the Parameters from the model and
     returns them as a tuple as well as their original attribute names.
@@ -136,7 +107,7 @@ def extract_weights(

 def extract_buffers(
     mod: nn.Module,
-) -> Tuple[Tuple[Tensor, ...], Tuple[str, ...], Dict[str, List[List[str]]]]:
+) -> Tuple[Tuple[Tensor, ...], Tuple[str, ...], Dict[str, List[str]]]:
     return _extract_members(mod, mod.named_buffers, lambda x: x)

@@ -151,23 +122,23 @@ def load_weights(
     Note that the `params` are regular Tensors (that can have history) and so are left
     as Tensors. This means that mod.parameters() will still be empty after this call.
     """
-    for name, p in zip(names, params):
-        if as_params:
-            p = nn.Parameter(p)
-        _del_nested_attr(mod, name.split("."))
-        _set_nested_attr(mod, name.split("."), p)
+    accessor = NamedMemberAccessor(mod)
+    if as_params:
+        params = [nn.Parameter(p) for p in params]
+    accessor.set_tensors(names, params)


 def _swap_state(
-    mod: nn.Module, names_map: Dict[str, List[List[str]]], elems: Iterable[Tensor]
+    mod: nn.Module, names_map: Dict[str, List[str]], elems: Iterable[Tensor]
 ) -> List[Tensor]:
     result: List[Tensor] = []
+    accessor = NamedMemberAccessor(mod)
     for (_, attr_names), elem in zip(names_map.items(), elems):
         for i, attr_name in enumerate(attr_names):
             if i == 0:
-                result.append(_get_nested_attr(mod, attr_name))
-            _del_nested_attr(mod, attr_name)
-            _set_nested_attr(mod, attr_name, elem)
+                result.append(accessor.swap_tensor(attr_name, elem))
+            else:
+                accessor.set_tensor(attr_name, elem)
     return result

@@ -177,8 +148,8 @@ def load_buffers(
     buffers: Sequence[Tensor],
     as_params: bool = False,
 ) -> None:
-    for name, p in zip(names, buffers):
-        _set_nested_attr(mod, name.split("."), p)
+    accessor = NamedMemberAccessor(mod)
+    accessor.set_tensors(names, buffers)


 def load_state(
@@ -290,8 +261,8 @@ class FunctionalModuleWithBuffers(nn.Module):
         stateless_model: nn.Module,
         param_names: Tuple[str, ...],
         buffer_names: Tuple[str, ...],
-        param_names_map: Dict[str, List[List[str]]],
-        buffer_names_map: Dict[str, List[List[str]]],
+        param_names_map: Dict[str, List[str]],
+        buffer_names_map: Dict[str, List[str]],
     ) -> None:
         super(FunctionalModuleWithBuffers, self).__init__()
         self.stateless_model = stateless_model
@@ -345,7 +316,7 @@ class FunctionalModule(nn.Module):
         self,
         stateless_model: nn.Module,
         param_names: Tuple[str, ...],
-        names_map: Dict[str, List[List[str]]],
+        names_map: Dict[str, List[str]],
     ) -> None:
         super(FunctionalModule, self).__init__()
         self.stateless_model = stateless_model
@@ -567,8 +538,7 @@ def combine_state_for_ensemble(
     model0_typ = type(models[0])
     if not all(type(m) == model0_typ for m in models):
         raise RuntimeError(
-            "combine_state_for_ensemble: Expected all models to "
-            "be of the same class."
+            "combine_state_for_ensemble: Expected all models to be of the same class."
         )
     funcs, params, buffers = zip(
         *[make_functional_with_buffers(model) for model in models]
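The pattern being replaced throughout this file — recursive _del/_set/_get_nested_attr helpers — is subsumed by NamedMemberAccessor, which resolves dotted paths with submodule caching. A sketch of the equivalence (module layout is illustrative, not from the diff):

    import torch
    from torch.nn.utils._named_member_accessor import NamedMemberAccessor

    class Net(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.l1 = torch.nn.Linear(1, 1)

    model = Net()
    value = torch.zeros(1, 1)

    # Before: _set_nested_attr(model, ['l1', 'weight'], value)
    NamedMemberAccessor(model).set_tensor('l1.weight', value)
    assert model.l1.weight is value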
torch/nn/utils/_named_member_accessor.py (new file, 341 lines)
@@ -0,0 +1,341 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Dict, Iterable, List, Tuple

import torch


_MISSING: torch.Tensor = object()  # type: ignore[assignment]


def set_tensor(module: "torch.nn.Module", name: str, tensor: torch.Tensor) -> None:
    if not isinstance(module, torch.nn.Module):
        raise TypeError(f"{module} is not an instance of torch.nn.Module")
    if not isinstance(tensor, torch.Tensor) and tensor is not None:
        raise TypeError(f"{tensor} is not an instance of torch.Tensor")
    if "." in name:
        raise KeyError('tensor name can\'t contain "."')
    if name == "":
        raise KeyError('tensor name can\'t be empty string ""')
    if name in module._parameters:
        module._parameters[name] = tensor  # type: ignore[assignment]
    elif name in module._buffers:
        module._buffers[name] = tensor
    else:
        setattr(module, name, tensor)


def swap_tensor(
    module: "torch.nn.Module",
    name: str,
    tensor: torch.Tensor,
    allow_missing: bool = False,
) -> torch.Tensor:
    if not isinstance(module, torch.nn.Module):
        raise TypeError(f"{module} is not an instance of torch.nn.Module")
    if (
        tensor is not _MISSING
        and not isinstance(tensor, torch.Tensor)
        and tensor is not None
    ):
        raise TypeError(f"{tensor} is not an instance of torch.Tensor")
    if "." in name:
        raise KeyError('tensor name can\'t contain "."')
    if name == "":
        raise KeyError('tensor name can\'t be empty string ""')

    orig_tensor: torch.Tensor
    if name in module._parameters:
        orig_tensor = module._parameters[name]  # type: ignore[assignment]
        if tensor is not _MISSING:
            module._parameters[name] = tensor  # type: ignore[assignment]
        else:
            del module._parameters[name]
    elif name in module._buffers:
        orig_tensor = module._buffers[name]  # type: ignore[assignment]
        if tensor is not _MISSING:
            module._buffers[name] = tensor
        else:
            del module._buffers[name]
    else:
        try:
            orig_tensor = getattr(module, name)
        except AttributeError as ex:
            if not allow_missing:
                raise AttributeError(
                    f"{module._get_name()} has no attribute `{name}`"
                ) from ex
            orig_tensor = _MISSING
        if (
            orig_tensor is not _MISSING
            and not isinstance(orig_tensor, torch.Tensor)
            and orig_tensor is not None
        ):
            raise TypeError(
                f"attribute `{name}`: {orig_tensor} is not an instance of torch.Tensor"
            )
        if tensor is not _MISSING:
            setattr(module, name, tensor)
        elif hasattr(module, name):
            delattr(module, name)
    return orig_tensor


class NamedMemberAccessor:
    """
    A class that provides a way to access the submodules and parameters/buffers
    of a module. It provides a caching mechanism to speed up submodule lookups.
    This is useful for functional programming to manipulate the module state.
    """

    def __init__(self, module: "torch.nn.Module") -> None:
        self.module = module
        self.memo: Dict[str, torch.nn.Module] = {}

    # Nested attribute access

    def get_submodule(self, name: str) -> "torch.nn.Module":
        """
        Return the submodule specified by the given path.
        For example, to get the submodule mod.layer1.conv1,
        use accessor.get_submodule("layer1.conv1")

        Compared to mod.get_submodule("layer1.conv1"), this method will cache the
        intermediate submodule access to speed up future lookups.
        """
        if not name:
            return self.module

        try:
            return self.memo[name]
        except KeyError:
            prefix, dot, attr = name.rpartition(".")
            if dot:
                module = self.get_submodule(prefix)
            else:
                module = self.module
            try:
                submodule = getattr(module, attr)
            except AttributeError as ex:
                raise AttributeError(
                    f"{module._get_name()} has no attribute `{attr}`"
                ) from ex
            if not isinstance(submodule, torch.nn.Module):
                raise TypeError(
                    f"submodule `{name}`: {submodule} is not an instance of torch.nn.Module"
                )
            self.memo[name] = submodule
            return submodule

    def get_tensor(self, name: str) -> torch.Tensor:
        """
        Get the tensor specified by the given path.
        For example, to get the attribute mod.layer1.conv1.weight,
        use accessor.get_tensor('layer1.conv1.weight')

        Compared to mod.get_parameter("layer1.conv1.weight"), this method will
        cache the intermediate submodule access to speed up future lookups.
        """
        prefix, _, attr = name.rpartition(".")
        submodule = self.get_submodule(prefix)
        try:
            tensor = getattr(submodule, attr)
        except AttributeError as ex:
            raise AttributeError(
                f"{submodule._get_name()} has no attribute `{name}`"
            ) from ex
        if not isinstance(tensor, torch.Tensor) and tensor is not None:
            raise TypeError(f"{tensor} is not an instance of torch.Tensor")
        return tensor  # type: ignore[return-value]

    def set_tensor(self, name: str, value: torch.Tensor) -> None:
        """
        Set the attribute specified by the given path to value.
        For example, to set the attribute mod.layer1.conv1.weight,
        use accessor.set_tensor("layer1.conv1.weight", value)
        """
        prefix, _, attr = name.rpartition(".")
        set_tensor(self.get_submodule(prefix), attr, value)

    def del_tensor(self, name: str) -> None:
        """
        Delete the attribute specified by the given path.
        For example, to delete the attribute mod.layer1.conv1.weight,
        use accessor.del_tensor("layer1.conv1.weight")
        """
        prefix, _, attr = name.rpartition(".")
        submodule = self.get_submodule(prefix)
        try:
            delattr(submodule, attr)
        except AttributeError as ex:
            raise AttributeError(
                f"{submodule._get_name()} has no attribute `{name}`"
            ) from ex

    def swap_tensor(
        self, name: str, value: torch.Tensor, allow_missing: bool = False
    ) -> torch.Tensor:
        """
        Swap the attribute specified by the given path to value.
        For example, to swap the attribute mod.layer1.conv1.weight,
        use accessor.swap_tensor("layer1.conv1.weight", value)
        """
        prefix, _, attr = name.rpartition(".")
        return swap_tensor(
            self.get_submodule(prefix), attr, value, allow_missing=allow_missing
        )

    # Batched operations

    def get_tensors(self, names: Iterable[str]) -> List[torch.Tensor]:
        """
        Get the tensors specified by the given paths.
        For example, to get the attributes mod.layer1.conv1.weight and
        mod.layer1.conv1.bias, use accessor.get_tensors(["layer1.conv1.weight",
        "layer1.conv1.bias"])
        """
        return [self.get_tensor(name) for name in names]

    def set_tensors(self, names: Iterable[str], values: Iterable[torch.Tensor]) -> None:
        """
        Set the attributes specified by the given paths to values.
        For example, to set the attributes mod.layer1.conv1.weight and
        mod.layer1.conv1.bias, use accessor.set_tensors(["layer1.conv1.weight",
        "layer1.conv1.bias"], [weight, bias])
        """
        if not isinstance(names, (list, tuple)):
            names = list(names)
        if not isinstance(values, (list, tuple)):
            values = list(values)
        assert len(names) == len(values), "names and values must have the same length"

        for name, value in zip(names, values):
            self.set_tensor(name, value)

    def set_tensors_dict(self, named_tensors: Dict[str, torch.Tensor]) -> None:
        """
        Set the attributes specified by the given paths to values.
        For example, to set the attributes mod.layer1.conv1.weight and
        mod.layer1.conv1.bias, use accessor.set_tensors_dict({
            "layer1.conv1.weight": weight,
            "layer1.conv1.bias": bias,
        })
        """
        for name, value in named_tensors.items():
            self.set_tensor(name, value)

    def del_tensors(self, names: Iterable[str]) -> None:
        """
        Delete the attributes specified by the given paths.
        For example, to delete the attributes mod.layer1.conv1.weight and
        mod.layer1.conv1.bias, use accessor.del_tensors(["layer1.conv1.weight",
        "layer1.conv1.bias"])
        """
        for name in names:
            self.del_tensor(name)

    def swap_tensors(
        self,
        names: Iterable[str],
        values: Iterable[torch.Tensor],
        allow_missing: bool = False,
    ) -> List[torch.Tensor]:
        """
        Swap the attributes specified by the given paths to values.
        For example, to swap the attributes mod.layer1.conv1.weight and
        mod.layer1.conv1.bias, use accessor.swap_tensors(["layer1.conv1.weight",
        "layer1.conv1.bias"], [weight, bias])
        """
        if not isinstance(names, (list, tuple)):
            names = list(names)
        if not isinstance(values, (list, tuple)):
            values = list(values)
        assert len(names) == len(values), "names and values must have the same length"

        return [
            self.swap_tensor(name, value, allow_missing=allow_missing)
            for name, value in zip(names, values)
        ]

    def swap_tensors_dict(
        self, named_tensors: Dict[str, torch.Tensor], allow_missing: bool = False
    ) -> Tuple[Dict[str, torch.Tensor], List[str]]:
        """
        Swap the attributes specified by the given paths to values.
        For example, to swap the attributes mod.layer1.conv1.weight and
        mod.layer1.conv1.bias, use accessor.swap_tensors_dict({
            "layer1.conv1.weight": weight,
            "layer1.conv1.bias": bias,
        })
        """
        orig_named_tensors = {}
        missing_keys = []
        try:
            for name, tensor in named_tensors.items():
                orig_tensor = self.swap_tensor(name, tensor, allow_missing=True)
                if orig_tensor is _MISSING:
                    missing_keys.append(name)
                orig_named_tensors[name] = orig_tensor
        except Exception:
            # Swap back if any exception occurs
            for name, orig_tensor in orig_named_tensors.items():
                self.swap_tensor(name, orig_tensor, allow_missing=True)
            raise
        if missing_keys and not allow_missing:
            # Swap back if any key is missing when allow_missing is False
            for name, orig_tensor in orig_named_tensors.items():
                self.swap_tensor(name, orig_tensor, allow_missing=True)
            raise RuntimeError(
                "Missing key(s): {}.".format(", ".join(map(repr, missing_keys)))
            )
        return orig_named_tensors, missing_keys

    def check_keys(self, keys: Iterable[str]) -> Tuple[List[str], List[str]]:
        """
        Check that the given keys are valid.
        """
        keys = set(keys)
        valid_keys = set(name for name, _ in self.named_tensors(remove_duplicate=False))
        missing_keys = valid_keys - keys
        unexpected_keys = keys - valid_keys
        return sorted(missing_keys), sorted(unexpected_keys)

    # Shortcut methods

    def named_parameters(
        self,
        remove_duplicate: bool = True,
    ) -> Iterable[Tuple[str, torch.Tensor]]:
        """
        Iterate over all the parameters in the module.
        """
        yield from self.module.named_parameters(remove_duplicate=remove_duplicate)

    def named_buffers(
        self,
        remove_duplicate: bool = True,
    ) -> Iterable[Tuple[str, torch.Tensor]]:
        """
        Iterate over all the buffers in the module.
        """
        yield from self.module.named_buffers(remove_duplicate=remove_duplicate)

    def named_tensors(
        self,
        remove_duplicate: bool = True,
    ) -> Iterable[Tuple[str, torch.Tensor]]:
        """
        Iterate over all the tensors in the module.
        """
        yield from self.module.named_parameters(remove_duplicate=remove_duplicate)
        yield from self.module.named_buffers(remove_duplicate=remove_duplicate)

    def named_modules(
        self,
        remove_duplicate: bool = True,
    ) -> Iterable[Tuple[str, "torch.nn.Module"]]:
        """
        Iterate over all the modules in the module.
        """
        yield from self.module.named_modules(remove_duplicate=remove_duplicate)
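A short usage sketch of this (private) accessor, not part of the diff; module and tensor names are illustrative:

    import torch
    from torch.nn.utils._named_member_accessor import NamedMemberAccessor

    model = torch.nn.Sequential(torch.nn.Linear(2, 2))
    accessor = NamedMemberAccessor(model)

    new_weight = torch.zeros(2, 2)
    orig = accessor.swap_tensor('0.weight', new_weight)  # returns the previous tensor
    assert model[0].weight is new_weight
    accessor.swap_tensor('0.weight', orig)               # swap the original back
    assert model[0].weight is orig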
@@ -1,17 +1,25 @@
 import contextlib
-from typing import Any, Callable, Dict, Iterator, List, Tuple, Union, Set, Optional
+import warnings
+from collections import defaultdict
+from typing import Any, Callable, Dict, Iterator, List, Set, Tuple, Union

 import torch
 from torch import Tensor
+from torch.nn.utils._named_member_accessor import NamedMemberAccessor

 __all__ = ["functional_call"]

 # We avoid typing module here because module attributes are declared as Union[Parameter, Tensor] by default
 # and using other types causes mypy errors
+# TODO: remove this unreferenced function when `torch.nn.utils._stateless` is removed
 def _change_class(module, params_and_buffers) -> None:
+    warnings.warn(
+        "The function `torch.nn.utils.stateless._change_class` is private "
+        "and it is deprecated now. It may be removed in a future release.",
+        DeprecationWarning,
+    )
     cls = module.__class__
-    attr_to_path : Dict[str, str] = module._attr_to_path
+    attr_to_path: Dict[str, str] = module._attr_to_path

     def _getattribute(self, name: str) -> Any:
         if name in attr_to_path:
@@ -37,144 +45,169 @@ def _change_class(module, params_and_buffers) -> None:
     module._orig_class = cls


-def _create_tied_weights_map(module: 'torch.nn.Module', params_and_buffers: Dict[str, Tensor]) -> Dict[str, str]:
+def _untie_named_tensors_map(
+    module: "torch.nn.Module",
+    parameters_and_buffers: Dict[str, Tensor],
+) -> Dict[str, Tensor]:
     """
-    _create_tied_weights_map(module: Module, params_and_buffers: Dict[str, Tensor]) -> Dict[str, str]
+    Unties all tied tensors in the module to parameters_and_buffers.

-    Creates a weight map of {tied_name: name_given_by_user} for all weights where one of their tied weights is passed
+    This function returns a new untied_parameters_and_buffers dictionary and leaves the original
+    parameters_and_buffers dictionary unchanged. It adds new (missing) keys for tied tensors
+    in the module to untied_parameters_and_buffers. The value of each new key is the user-given value
+    in the original parameters_and_buffers dictionary.

-    ex: Foo() has self.foo and self.tied_foo, which are tied. If a user passed {'foo': ...} as the reparametrization,
-    this would return {'tied_foo': 'foo'}. Similarly if a user passed {'tied_foo': ...}, this returns
-    {'tied_foo': 'foo'}.
+    If more than one user-given value is passed for the same tied tensor, it will raise an error.

-    ex: If there aren't any tied weights and the user passed values for every parameter and buffer, this will return a
-    map where every name maps to an empty set: {'l1.weight': set(), 'l1.bias': set(), ...}
+    For example, if the module has two tied weights self.foo and self.tied_foo and the user passes
+    {'foo': foo_value, ...}, this will return {'foo': foo_value, 'tied_foo': foo_value, ...}. If the
+    user passes {'foo': foo_value, 'tied_foo': tied_foo_value, ...}, it will raise an error. If the
+    user passes {'foo': foo_value, 'tied_foo': foo_value, ...}, it will not raise an error.

-    ex: The map only contains values that a user is reparametrizing. For example, if module = nn.Linear(...) and the
-    user only passed a new value for 'bias', this returns: {'bias': set()}
+    Args:
+        module (torch.nn.Module): the module to determine which tensors are tied.
+        parameters_and_buffers (Dict[str, Tensor]): a map of {name: tensor} for reparametrizing the module.

-    This is useful because we will start by reparametrizing all the keys of params_and_buffers, then all the keys from
-    this returned dictionary.
+    Returns:
+        A new untied version of the parameters_and_buffers dictionary.
+
+    Raises:
+        ValueError: if more than one user-given value is passed for the same tied tensor.
     """
+    # A map of {name: tensor} for all tensors (including tied ones) in the module.
+    all_named_tensors: Dict[str, Tensor] = {}
+    all_named_tensors.update(module.named_parameters(remove_duplicate=False))
+    all_named_tensors.update(module.named_buffers(remove_duplicate=False))

-    # The basic algorithm looks like:
-    #   - index all weights by their original tensor value to find tied weights
-    #   - when we encounter a weight not used by the user, we save it in a set (second element in the tuple)
-    #   - when we run into a weight used by the user, we save that separate from the set as the first element in the tuple
-    #   - ending map looks like {tensor: (name_given_by_user, set(all_tied_names))}
-    #   - then loop through the values of this map (name_given_by_user and set(all_tied_names))
-    #   - for each element of all_tied_names, add {tied_name: name_given_by_user} to a new map
+    # A map of {tensor: set(all_tied_names)} for all tensor names in the module.
+    tensor_to_tied_names_map: Dict[Tensor, Set[str]] = defaultdict(set)
+    for name, tensor in all_named_tensors.items():
+        tensor_to_tied_names_map[tensor].add(name)

-    names = params_and_buffers.keys()
-    weight_to_name_and_tied_names: Dict[torch.Tensor, Tuple[Optional[str], Set[str]]] = {}
+    # A map of {tied_name: set(all_tied_names)} for all tensor names in the module.
+    # If a name is not tied, it will not be in this map.
+    tied_names_map: Dict[str, Set[str]] = {}
+    for tied_names in tensor_to_tied_names_map.values():
+        if len(tied_names) > 1:
+            for tied_name in tied_names:
+                tied_names_map[tied_name] = tied_names

-    # create a map keyed by tensor value so that tied weights get mapped to the same key. The value is the interesting
-    # part: at the end it's (used_name, (tied_names)).
-    # For example, in the first example where there's tied weights self.foo and self.tied_foo and the user passes a
-    # value for self.foo, this will return {torch.Tensor(...): ('foo', set('tied_foo'))}
-    def add_to_name_map(n: str, t: torch.Tensor):
-        # if the tensor hasn't been seen before, add it to the map
-        if t not in weight_to_name_and_tied_names:
-            weight_to_name_and_tied_names[t] = (n, set()) if n in names else (None, {n})
-            return
+    # Make sure the user didn't pass multiple values for the same tied tensor.
+    given_names = set(parameters_and_buffers.keys())
+    given_names_for_tied_tensors = given_names.intersection(tied_names_map.keys())
+    for given_name in given_names_for_tied_tensors:
+        tied_names = tied_names_map[given_name]
+        if (
+            # Detect if there are multiple keys present for the same tied tensor.
+            len(tied_names.intersection(given_names_for_tied_tensors)) > 1
+            # Only raise an error if the user passed multiple values for the same tied tensor.
+            # If all given values are the same, don't raise.
+            and len({parameters_and_buffers[tied_name] for tied_name in tied_names})
+            != 1
+        ):
+            raise ValueError(
+                f"functional_call got multiple values for keys {sorted(tied_names)}, "
+                f"which are tied. Consider using tie_weights=False"
+            )

-        # if the name is not used by the user, we add it to the tied set
-        if n not in names:
-            weight_to_name_and_tied_names[t][1].add(n)
-            return
-
-        # check that the user didn't pass two different tensors for the same tied weight
-        first_seen_name = weight_to_name_and_tied_names[t][0]
-
-        # if they didn't pass multiple names for tied weights or used the same tensor, we set the used name
-        if first_seen_name is None or params_and_buffers[n] is params_and_buffers[first_seen_name]:
-            weight_to_name_and_tied_names[t] = (n, weight_to_name_and_tied_names[t][1])
-            return
-
-        raise ValueError(f"functional_call got values for both {n} and {first_seen_name}, which are tied. " +
-                         "Consider using tie_weights=False")
-
-    tensor: Tensor
-    for name, tensor in module.named_parameters(remove_duplicate=False):
-        add_to_name_map(name, tensor)
-
-    for name, tensor in module.named_buffers(remove_duplicate=False):
-        add_to_name_map(name, tensor)
-
-    # make {tied_name: name_given_by_user} from pairs of (name_given_by_user, set(all_tied_names))
-    tied_weights_to_given_name = {}
-    for name_given_by_user, tied_names in weight_to_name_and_tied_names.values():
-        if name_given_by_user is None:  # no mapping was passed for this tensor, use original tensor
-            continue
-        for tied_name in tied_names:
-            tied_weights_to_given_name[tied_name] = name_given_by_user
-    return tied_weights_to_given_name
-
-
-def _create_swap_params(params_and_buffers):
-    def _swap_parameters(module, tensor_name: str, full_path: str, tensor: Optional[Tensor]) -> None:
-        # Changes the module class to get a new __getattr__ dunder method
-        # that looks for the reparametrized tensor
-        if hasattr(module, "_attr_to_path"):
-            module._attr_to_path[tensor_name] = full_path
-        else:
-            module._attr_to_path = {}
-            module._attr_to_path[tensor_name] = full_path
-            _change_class(module, params_and_buffers)
-    return _swap_parameters
-
-
-def _remove_swap(module, name: str, full_path: str) -> None:
-    if hasattr(module, "_orig_class"):
-        module.__class__ = module._orig_class
-        delattr(module, "_orig_class")
-        delattr(module, "_attr_to_path")
+    # Untie the given named tensor map
+    # Make a copy so the original dict is not modified
+    untied_parameters_and_buffers = parameters_and_buffers.copy()
+    for given_name in given_names_for_tied_tensors:
+        for tied_name in tied_names_map[given_name]:
+            untied_parameters_and_buffers[tied_name] = parameters_and_buffers[
+                given_name
+            ]
+    return untied_parameters_and_buffers
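
To make the tied-weights behavior above concrete, here is a minimal sketch (the module and names are illustrative, not part of this change):

    import torch
    from torch.nn.utils import stateless

    class TiedModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.foo = torch.nn.Parameter(torch.zeros(1))
            self.tied_foo = self.foo  # same Parameter registered under two names

    m = TiedModule()
    value = torch.ones(1)

    # Passing a value for one tied name propagates it to all tied names.
    untied = stateless._untie_named_tensors_map(m, {"foo": value})
    print(sorted(untied))               # ['foo', 'tied_foo']
    print(untied["tied_foo"] is value)  # True

    # Passing conflicting values for 'foo' and 'tied_foo' raises a ValueError,
    # as documented above.
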
 @contextlib.contextmanager
 def _reparametrize_module(
-    module: 'torch.nn.Module',
+    module: "torch.nn.Module",
     parameters_and_buffers: Dict[str, Tensor],
     tie_weights: bool = False,
+    *,
+    strict: bool = False,
 ) -> Iterator[None]:
-    tied_weights_map = _create_tied_weights_map(module, parameters_and_buffers) if tie_weights else {}
-    for name, tensor in parameters_and_buffers.items():
-        _apply_func_submodules(
-            _create_swap_params(parameters_and_buffers),
-            module, name.split("."), name, (tensor,))
-    for tied_name, user_given_name in tied_weights_map.items():
-        _apply_func_submodules(
-            _create_swap_params(parameters_and_buffers),
-            module, tied_name.split("."), user_given_name, (None,))
+    if tie_weights:
+        untied_parameters_and_buffers = _untie_named_tensors_map(
+            module, parameters_and_buffers
+        )
+    else:
+        untied_parameters_and_buffers = parameters_and_buffers
+
+    accessor = NamedMemberAccessor(module)
+    if strict:
+        missing_keys, unexpected_keys = accessor.check_keys(
+            untied_parameters_and_buffers
+        )
+        error_msgs = []
+        if len(unexpected_keys) > 0:
+            error_msgs.append(
+                "Unexpected key(s): {}.".format(", ".join(map(repr, unexpected_keys)))
+            )
+        if len(missing_keys) > 0:
+            error_msgs.append(
+                "Missing key(s): {}.".format(", ".join(map(repr, missing_keys)))
+            )
+        if len(error_msgs) > 0:
+            raise RuntimeError(
+                "Error(s) in reparametrizing for {}:\n\t{}".format(
+                    module._get_name(), "\n\t".join(error_msgs)
+                )
+            )
+
+    orig_parameters_and_buffers: Dict[str, Tensor] = {}
     try:
+        orig_parameters_and_buffers, _ = accessor.swap_tensors_dict(
+            untied_parameters_and_buffers, allow_missing=True
+        )
         yield
     finally:
-        for name in parameters_and_buffers:
-            _apply_func_submodules(
-                _remove_swap,
-                module, name.split("."), name, ())
+        new_parameters_and_buffers, _ = accessor.swap_tensors_dict(
+            orig_parameters_and_buffers, allow_missing=True
+        )
+        # Sometimes the module is not completely stateless and has some in-place modifications on
+        # the _parameters and _buffers dictionaries.
+        # Write the changed parameters and buffers back to the original dict.
+        parameters_and_buffers.update(
+            {
+                k: new_parameters_and_buffers[k]
+                for k in parameters_and_buffers
+                if k in new_parameters_and_buffers
+            }
+        )
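
As a usage sketch of this context manager (a private API; the module and shapes are illustrative):

    import torch
    from torch.nn.utils import stateless

    module = torch.nn.Linear(3, 3)
    new_tensors = {"weight": torch.randn(3, 3), "bias": torch.randn(3)}

    with stateless._reparametrize_module(module, new_tensors):
        out = module(torch.randn(1, 3))  # forward runs with the swapped-in tensors
    # On exit, the original parameters and buffers are swapped back in place.

    # With strict=True, missing or unexpected keys raise a RuntimeError before the swap:
    # with stateless._reparametrize_module(module, {"weight": torch.randn(3, 3)}, strict=True):
    #     ...
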
+# TODO: remove this unreferenced function when `torch.nn.utils._stateless` is removed
 def _apply_func_submodules(
     func: Callable[..., None],
-    module: 'torch.nn.Module',
+    module: "torch.nn.Module",
     path: List[str],
     full_path: str,
     args: Tuple,
 ):
+    warnings.warn(
+        "The function `torch.nn.utils.stateless._apply_func_submodules` is private "
+        "and it is deprecated now. It may be removed in a future release.",
+        DeprecationWarning,
+    )
     if len(path) == 1:
         func(module, path[0], full_path, *args)
     else:
-        _apply_func_submodules(func, getattr(module, path[0]), path[1:], full_path, args)
+        _apply_func_submodules(
+            func, getattr(module, path[0]), path[1:], full_path, args
+        )


 def functional_call(
-    module: 'torch.nn.Module',
+    module: "torch.nn.Module",
     parameters_and_buffers: Dict[str, Tensor],
     args: Union[Any, Tuple],
     kwargs: Dict[str, Any] = None,
     *,
     tie_weights: bool = True,
+    strict: bool = False,
 ):
     r"""Performs a functional call on the module by replacing the module parameters
     and buffers with the provided ones.
@@ -229,6 +262,9 @@ def functional_call(
             tied in the reparametrized version. Therefore, if True and different values are passed for the tied
             parameters and buffers, it will error. If False, it will not respect the originally tied parameters and
             buffers unless the values passed for both weights are the same. Default: True.
+        strict (bool, optional): If True, then the parameters and buffers passed in must match the parameters and
+            buffers in the original module. Therefore, if True and there are any missing or unexpected keys, it will
+            error. Default: False.

     Returns:
         Any: the result of calling ``module``.
@@ -236,35 +272,47 @@ def functional_call(
     warnings.warn(
         "This API is deprecated as of PyTorch 2.0 and will be removed in a future "
         "version of PyTorch. Please use torch.func.functional_call instead "
-        "which is a drop-in replacement for this API.")
-    return _functional_call(module, parameters_and_buffers, args, kwargs,
-                            tie_weights=tie_weights)
+        "which is a drop-in replacement for this API."
+    )
+
+    return _functional_call(
+        module,
+        parameters_and_buffers,
+        args,
+        kwargs,
+        tie_weights=tie_weights,
+        strict=strict,
+    )
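
Since the deprecated entry point now just forwards to the shared implementation, migration is mechanical; a sketch (module and inputs are illustrative):

    import torch
    import torch.nn.utils.stateless as stateless

    module = torch.nn.Linear(3, 3)
    params = dict(module.named_parameters())
    x = torch.randn(1, 3)

    out_old = stateless.functional_call(module, params, (x,))   # emits a deprecation warning
    out_new = torch.func.functional_call(module, params, (x,))  # drop-in replacement
    torch.testing.assert_close(out_old, out_new)
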
 def _functional_call(
-    module: 'torch.nn.Module',
+    module: "torch.nn.Module",
     parameters_and_buffers: Dict[str, Tensor],
     args: Union[Any, Tuple],
     kwargs: Dict[str, Any] = None,
     *,
     tie_weights: bool = True,
+    strict: bool = False,
 ):
     # TODO allow kwargs such as unsafe and others for parametrization
     if (
-        torch.jit.is_tracing()
-        or torch.jit.is_scripting()
-        or isinstance(module, (
+        torch.jit.is_tracing()
+        or torch.jit.is_scripting()
+        or isinstance(
+            module,
+            (
                 torch.jit.RecursiveScriptModule,
                 torch.jit.ScriptModule,
-                torch.jit.ScriptFunction)
-        )
+                torch.jit.ScriptFunction,
+            ),
+        )
     ):
         raise RuntimeError("The stateless API can't be used with Jitted modules")
     if kwargs is None:
         kwargs = {}
-    with _reparametrize_module(module, parameters_and_buffers, tie_weights):
-        if isinstance(args, tuple):
-            out = module(*args, **kwargs)
-        else:
-            out = module(args, **kwargs)
-    return out
+    if not isinstance(args, tuple):
+        args = (args,)
+    with _reparametrize_module(
+        module, parameters_and_buffers, tie_weights=tie_weights, strict=strict
+    ):
+        return module(*args, **kwargs)
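
One behavioral detail worth illustrating: a non-tuple ``args`` is normalized to a 1-tuple before the call, so a single tensor can be passed directly (sketch; shapes are illustrative):

    import torch

    module = torch.nn.Linear(2, 2)
    params = dict(module.named_parameters())
    x = torch.randn(1, 2)

    out_a = torch.func.functional_call(module, params, x)     # single tensor argument
    out_b = torch.func.functional_call(module, params, (x,))  # equivalent explicit tuple
    torch.testing.assert_close(out_a, out_b)
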