Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-21 05:34:18 +08:00).
Commit: [stateless] add weight tying support (#90477)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90477; approved by: https://github.com/zou3519
@ -22,6 +22,17 @@ class MockModule(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
return self.l1(x) + self.buffer
|
||||
|
||||
class MockTiedModule(torch.nn.Module):
    """Module with deliberately tied parameters and buffers.

    ``tied_bias`` is the very same tensor object as ``l1.bias`` and
    ``tied_buffer`` is the same object as ``buffer``, so tests can
    exercise ``functional_call``'s ``tie_weights`` handling.
    """

    def __init__(self):
        super().__init__()
        self.l1 = torch.nn.Linear(1, 1)
        # Tie the bias: register the same tensor under a second name.
        self.tied_bias = self.l1.bias
        self.register_buffer('buffer', torch.ones(1))
        # Tie the buffer the same way.
        self.register_buffer('tied_buffer', self.buffer)

    def forward(self, x):
        # Each tied alias is added separately, so tied values count twice.
        return self.l1(x) + self.tied_bias + self.buffer + self.tied_buffer
|
||||
|
||||
|
||||
class TestStatelessFunctionalAPI(TestCase):
|
||||
def _run_call_with_mock_module(self, module, functional_call, device='cpu', prefix=''):
|
||||
@ -156,7 +167,7 @@ class TestStatelessFunctionalAPI(TestCase):
|
||||
'l1.m.buffer': buffer}
|
||||
prev_weight = module.l1.weight.clone()
|
||||
prev_buffer = module.buffer.clone()
|
||||
res = functional_call(module, parameters, x)
|
||||
res = functional_call(module, parameters, x, tie_weights=False)
|
||||
self.assertEqual(x, res)
|
||||
# check that the weights remain unmodified and were correctly accesed
|
||||
cur_weight = module.l1.weight
|
||||
@ -217,6 +228,46 @@ class TestStatelessFunctionalAPI(TestCase):
|
||||
module = MockModule()
|
||||
module.tied_bias = module.l1.bias
|
||||
module.register_buffer("tied_buffer", module.buffer)
|
||||
|
||||
@parametrize("functional_call", [
    subtest(torch.func.functional_call, "torch_func"),
    subtest(stateless.functional_call, "stateless"),
])
def test_reparamertize_tie_weights(self, functional_call):
    """With tie_weights=True a value supplied for one tied name applies to
    every alias, so bias and buffer each contribute twice to the output."""
    module = MockTiedModule()
    weight = torch.tensor([[2.0]])
    bias = torch.tensor([5.0])
    buffer = torch.tensor([3.0])

    parameters = {'l1.weight': weight,
                  'l1.bias': bias,
                  'buffer': buffer}
    x = torch.randn(1, 1)
    out = functional_call(module, parameters, x, tie_weights=True)
    # tied_bias mirrors l1.bias and tied_buffer mirrors buffer.
    self.assertEqual(out, x * weight + bias + bias + buffer + buffer)
|
||||
|
||||
|
||||
@parametrize("functional_call", [
    subtest(torch.func.functional_call, "torch_func"),
    subtest(stateless.functional_call, "stateless"),
])
def test_reparamertize_tie_some_weights(self, functional_call):
    """Tying still applies when only a subset of tied names is given:
    l1.bias keeps its module value (shared with tied_bias), while the new
    buffer value propagates to tied_buffer."""
    module = MockTiedModule()
    weight = torch.tensor([[2.0]])
    buffer = torch.tensor([3.0])

    parameters = {'l1.weight': weight,
                  'buffer': buffer}
    x = torch.randn(1, 1)
    # Fix: use the parametrized `functional_call` — the original hard-coded
    # stateless.functional_call here, defeating the parametrization.
    out = functional_call(module, parameters, x, tie_weights=True)
    self.assertEqual(out, x * 2. + module.l1.bias + module.tied_bias + buffer + buffer)
|
||||
|
||||
@parametrize("functional_call", [
    subtest(torch.func.functional_call, "torch_func"),
    subtest(stateless.functional_call, "stateless"),
])
def test_tied_weights_errors(self, functional_call):
    """With tie_weights=True, conflicting values for two tied names must
    raise ValueError; identical tensors (or no alias at all) must not warn.

    NOTE(review): this span was a merged diff (old warning path and new
    raising path interleaved); reconstructed to the post-#90477 behavior,
    which raises — confirm against upstream.
    """
    module = MockTiedModule()
    weight = torch.tensor([[1.0]])
    bias = torch.tensor([0.0])
    buffer = torch.tensor([0.0])

    parameters = {'l1.weight': weight,
                  'l1.bias': bias,
                  'buffer': buffer}
    x = torch.randn(1, 1)
    self.assertNotWarn(lambda: functional_call(module, parameters, x, tie_weights=True))

    # If tied values are the very same tensors, shouldn't warn or raise.
    parameters['tied_bias'] = bias
    parameters['tied_buffer'] = buffer
    self.assertNotWarn(lambda: functional_call(module, parameters, x, tie_weights=True))
    del parameters['tied_bias']
    del parameters['tied_buffer']

    # Conflicting value for a tied parameter -> ValueError.
    with self.assertRaisesRegex(ValueError, "functional_call got values for both (l1.bias|tied_bias)"):
        parameters['tied_bias'] = torch.tensor([5.0])
        functional_call(module, parameters, x, tie_weights=True)
    del parameters['tied_bias']

    # Conflicting value for a tied buffer -> ValueError.
    with self.assertRaisesRegex(ValueError, "functional_call got values for both (buffer|tied_buffer)"):
        parameters['tied_buffer'] = torch.tensor([5.0])
        functional_call(module, parameters, x, tie_weights=True)
|
||||
|
||||
|
||||
def test_tied_weights_no_error_without_flag(self):
    """With tie_weights=False, supplying conflicting values for tied names
    is allowed and must not warn — ties are simply ignored."""
    module = MockTiedModule()
    weight = torch.tensor([[1.0]])
    bias = torch.tensor([0.0])
    buffer = torch.tensor([0.0])

    parameters = {'l1.weight': weight,
                  'l1.bias': bias,
                  'buffer': buffer}
    x = torch.randn(1, 1)
    self.assertNotWarn(lambda: stateless.functional_call(module, parameters, x, tie_weights=False))
    # A value that conflicts with the tied l1.bias is fine without the flag.
    parameters['tied_bias'] = torch.tensor([5.0])
    self.assertNotWarn(lambda: stateless.functional_call(module, parameters, x, tie_weights=False))
    del parameters['tied_bias']
    # Likewise for the tied buffer.
    parameters['tied_buffer'] = torch.tensor([5.0])
    self.assertNotWarn(lambda: stateless.functional_call(module, parameters, x, tie_weights=False))
|
||||
|
||||
@parametrize("functional_call", [
|
||||
subtest(torch.func.functional_call, "torch_func"),
|
||||
|
Reference in New Issue
Block a user