[stateless] add weight tying support (#90477)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/90477
Approved by: https://github.com/zou3519
This commit is contained in:
samdow
2023-01-10 17:02:48 -05:00
committed by PyTorch MergeBot
parent e03ac0ee8c
commit 8b3c4bc481
3 changed files with 202 additions and 24 deletions

View File

@ -22,6 +22,17 @@ class MockModule(torch.nn.Module):
def forward(self, x):
    # Linear layer output shifted by the registered buffer.
    activated = self.l1(x)
    return activated + self.buffer
class MockTiedModule(torch.nn.Module):
    """Module with deliberately tied members: ``tied_bias`` aliases the linear
    layer's bias, and ``tied_buffer`` aliases ``buffer`` (same storage, two names)."""

    def __init__(self):
        super().__init__()
        self.l1 = torch.nn.Linear(1, 1)
        # Alias the linear layer's bias so two parameter names share one tensor.
        self.tied_bias = self.l1.bias
        self.register_buffer('buffer', torch.ones(1))
        # Register the very same tensor again under a second buffer name.
        self.register_buffer('tied_buffer', self.buffer)

    def forward(self, x):
        linear_out = self.l1(x)
        # Tied members contribute twice: once under each of their names.
        return linear_out + self.tied_bias + self.buffer + self.tied_buffer
class TestStatelessFunctionalAPI(TestCase):
def _run_call_with_mock_module(self, module, functional_call, device='cpu', prefix=''):
@ -156,7 +167,7 @@ class TestStatelessFunctionalAPI(TestCase):
'l1.m.buffer': buffer}
prev_weight = module.l1.weight.clone()
prev_buffer = module.buffer.clone()
res = functional_call(module, parameters, x)
res = functional_call(module, parameters, x, tie_weights=False)
self.assertEqual(x, res)
# check that the weights remain unmodified and were correctly accessed
cur_weight = module.l1.weight
@ -217,6 +228,46 @@ class TestStatelessFunctionalAPI(TestCase):
module = MockModule()
module.tied_bias = module.l1.bias
module.register_buffer("tied_buffer", module.buffer)
@parametrize("functional_call", [
    subtest(torch.func.functional_call, "torch_func"),
    subtest(stateless.functional_call, "stateless")
])
def test_reparamertize_tie_weights(self, functional_call):
    """With tie_weights=True, replacing a member must also replace its tied alias."""
    module = MockTiedModule()
    new_weight = torch.tensor([[2.0]])
    new_bias = torch.tensor([5.0])
    new_buffer = torch.tensor([3.0])
    overrides = {
        'l1.weight': new_weight,
        'l1.bias': new_bias,
        'buffer': new_buffer,
    }
    inp = torch.randn(1, 1)
    result = functional_call(module, overrides, inp, tie_weights=True)
    # bias and buffer each appear twice: once directly, once via their tied alias
    self.assertEqual(result, inp * new_weight + new_bias + new_bias + new_buffer + new_buffer)
@parametrize("functional_call", [
    subtest(torch.func.functional_call, "torch_func"),
    subtest(stateless.functional_call, "stateless")
])
def test_reparamertize_tie_some_weights(self, functional_call):
    """When only some members are overridden, their tied aliases follow the new
    values, while untouched members (and their ties) keep the module's originals."""
    module = MockTiedModule()
    weight = torch.tensor([[2.0]])
    buffer = torch.tensor([3.0])
    parameters = {'l1.weight': weight,
                  'buffer': buffer}
    x = torch.randn(1, 1)
    # Use the parametrized functional_call: the original body hard-coded
    # stateless.functional_call here, so the torch.func variant was never tested.
    out = functional_call(module, parameters, x, tie_weights=True)
    # l1.bias was not overridden, so both it and tied_bias keep their original
    # values; the overridden buffer's tied alias follows the new buffer value.
    self.assertEqual(out, x * weight + module.l1.bias + module.tied_bias + buffer + buffer)
@parametrize("functional_call", [
    subtest(torch.func.functional_call, "torch_func"),
    subtest(stateless.functional_call, "stateless")
])
def test_tied_weights_errors(self, functional_call):
    """With tie_weights=True, conflicting values for tied names must raise
    ValueError; identical tensors under tied names are accepted silently."""
    module = MockTiedModule()
    weight = torch.tensor([[1.0]],)
    bias = torch.tensor([0.0])
    buffer = torch.tensor([0.0])
    parameters = {'l1.weight': weight,
                  'l1.bias': bias,
                  'buffer': buffer}
    x = torch.randn(1, 1)
    self.assertNotWarn(lambda: functional_call(module, parameters, x, tie_weights=True))

    # if tied values are the same tensors, shouldn't warn
    parameters['tied_bias'] = bias
    parameters['tied_buffer'] = buffer
    self.assertNotWarn(lambda: functional_call(module, parameters, x, tie_weights=True))
    del parameters['tied_bias']
    del parameters['tied_buffer']

    # distinct tensors supplied for two tied names conflict -> ValueError
    with self.assertRaisesRegex(ValueError, "functional_call got values for both (l1.bias|tied_bias)"):
        parameters['tied_bias'] = torch.tensor([5.0])
        functional_call(module, parameters, x, tie_weights=True)
    del parameters['tied_bias']
    with self.assertRaisesRegex(ValueError, "functional_call got values for both (buffer|tied_buffer)"):
        parameters['tied_buffer'] = torch.tensor([5.0])
        functional_call(module, parameters, x, tie_weights=True)
@parametrize("functional_call", [
    subtest(torch.func.functional_call, "torch_func"),
    subtest(stateless.functional_call, "stateless")
])
def test_tied_weights_no_error_without_flag(self, functional_call):
    """With tie_weights=False, conflicting values for tied names are accepted
    without raising or warning.

    Parametrized over both functional_call entry points for consistency with
    the sibling tied-weight tests (the original only covered the stateless path).
    """
    module = MockTiedModule()
    weight = torch.tensor([[1.0]],)
    bias = torch.tensor([0.0])
    buffer = torch.tensor([0.0])
    parameters = {'l1.weight': weight,
                  'l1.bias': bias,
                  'buffer': buffer}
    x = torch.randn(1, 1)
    self.assertNotWarn(lambda: functional_call(module, parameters, x, tie_weights=False))
    # These conflicting tied values raise with tie_weights=True; without the
    # flag they must pass through untouched.
    parameters['tied_bias'] = torch.tensor([5.0])
    self.assertNotWarn(lambda: functional_call(module, parameters, x, tie_weights=False))
    del parameters['tied_bias']
    parameters['tied_buffer'] = torch.tensor([5.0])
    self.assertNotWarn(lambda: functional_call(module, parameters, x, tie_weights=False))
@parametrize("functional_call", [
subtest(torch.func.functional_call, "torch_func"),