Add lightweight reparametrization for _stateless calls (#68969)

Summary:
https://github.com/pytorch/pytorch/issues/61447 introduced a mechanism for performing functional calls on a model using the reparametrization API. However, the overhead this added to a single call was too large.
I first tried to address this by modifying the reparametrization code to support plain tensors, but the required changes were too invasive: type checks and several parts of the code expect actual `nn.Module` objects, so that option was not feasible.

The benchmark uses resnet50 and calls the functional API with a parameters dict covering 0, 25, 50, 75, and 100% of the model's total parameters.

Benchmark script:
https://gist.github.com/emcastillo/f344a58638bd71d130c71c45f86f0c3a
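
A minimal sketch of what such a benchmark call looks like (assuming torchvision for resnet50 and the private `torch.nn.utils._stateless.functional_call` API introduced in #61447; the 50% split is illustrative):

```python
import torch
import torchvision
from torch.nn.utils import _stateless

model = torchvision.models.resnet50()
x = torch.randn(1, 3, 224, 224)

# Override roughly half of the named parameters; an empty dict
# corresponds to the 0% row, the full dict to the 100% row.
names = [name for name, _ in model.named_parameters()]
half = set(names[:len(names) // 2])
overrides = {name: torch.nn.Parameter(p.clone())
             for name, p in model.named_parameters() if name in half}

out = _stateless.functional_call(model, overrides, x)
```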

Times with the original reparametrization-based implementation:

| % of parameters passed | CPU Time (us) | GPU Time (us) |
|------------------------|---------------|---------------|
| regular call           | 5539          | 184909        |
| 0                      | 5561          | 184843        |
| 25                     | 11363         | 189236        |
| 50                     | 18716         | 195378        |
| 75                     | 22851         | 198641        |
| 100                    | 27441         | 202281        |

This PR instead swaps the `__getattr__` of the submodules so that, during the call, attribute lookups consult a dict holding only the replacement parameters. This removes the burden of instantiating custom parametrization modules and calling their forward just to retrieve a tensor.
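
As an illustration of the technique (a minimal sketch, not the code in this PR; the `_SwapGetattr` helper and its internals are hypothetical), the module's class can be temporarily replaced by a subclass whose `__getattr__` consults the replacement dict before falling back to the usual `nn.Module` lookup:

```python
import torch

class _SwapGetattr:
    """Temporarily redirect a module's attribute lookup to a dict of
    replacement tensors (hypothetical helper, illustration only)."""
    def __init__(self, module, replacements):
        self.module = module
        self.replacements = replacements

    def __enter__(self):
        self.orig_cls = self.module.__class__
        replacements = self.replacements
        orig_getattr = self.orig_cls.__getattr__

        def patched_getattr(mod, name):
            # Serve the tensor from the replacement dict if present;
            # otherwise defer to the regular nn.Module lookup, which
            # checks _parameters, _buffers and _modules.
            if name in replacements:
                return replacements[name]
            return orig_getattr(mod, name)

        # A throwaway subclass keeps other instances of the original
        # class unaffected by the patch.
        patched_cls = type(self.orig_cls.__name__, (self.orig_cls,),
                           {'__getattr__': patched_getattr})
        self.module.__class__ = patched_cls
        return self.module

    def __exit__(self, *exc):
        self.module.__class__ = self.orig_cls

lin = torch.nn.Linear(1, 1)
new_weight = torch.nn.Parameter(torch.ones(1, 1))
with _SwapGetattr(lin, {'weight': new_weight}):
    # forward reads self.weight through __getattr__, so it sees new_weight
    out = lin(torch.ones(1, 1))
```

Because only attribute lookup is intercepted, no parametrization modules are built and no extra `forward` calls are made to retrieve a tensor, which is where the CPU savings below come from.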

The execution times now are as follows:

| % of parameters passed     | CPU Time (us) | GPU Time (us) |
|----------------------------|---------------|---------------|
| regular call               | 5939          | 187533        |
| 0                          | 5899          | 187570        |
| 25                         | 8541          | 188953        |
| 50                         | 10045         | 189826        |
| 75                         | 11049         | 190344        |
| 100                        | 11911         | 190800        |
| functorch with 100% params | 14014         | 191727        |

The CPU-time overhead is now greatly reduced, and the GPU time barely increases because the remaining host-side work overlaps effectively with kernel execution.
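
For comparison, the functorch row uses that library's stateless-call style; a minimal sketch (the exact invocation is in the gist above; `make_functional_with_buffers` is functorch's API for modules with buffers such as resnet50's batch-norm stats):

```python
import torch
import torchvision
import functorch

model = torchvision.models.resnet50()
x = torch.randn(1, 3, 224, 224)

# Split the module into a pure function plus tuples holding 100% of
# its parameters and buffers, which are passed back in explicitly.
fmodel, params, buffers = functorch.make_functional_with_buffers(model)
out = fmodel(params, buffers, x)
```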

cc albanD zou3519

Pull Request resolved: https://github.com/pytorch/pytorch/pull/68969

Reviewed By: george-qi

Differential Revision: D33836360

Pulled By: albanD

fbshipit-source-id: 532561f64b18ca14c6ae2d77dcacb339397a589d
(cherry picked from commit fd4b6bdfbff4cb3d1da47b7fd73f1edfe43ba65c)

@@ -58,9 +58,16 @@ class TestStatelessFunctionalAPI(TestCase):
         jit_module = torch.jit.script(module)
         with self.assertRaisesRegex(
             RuntimeError,
-            r'delete methods or parameters'
+            r'used with Jitted modules'
         ):
             self._run_call_with_mock_module(jit_module)
+        x = torch.rand((1, 1))
+        traced_module = torch.jit.trace(module, x)
+        with self.assertRaisesRegex(
+            RuntimeError,
+            r'used with Jitted modules'
+        ):
+            self._run_call_with_mock_module(traced_module)
 
     @unittest.skipIf(not TEST_MULTIGPU, 'multi-GPU not supported')
     def test_functional_call_with_data_parallel(self):
@@ -127,12 +134,15 @@ class TestStatelessFunctionalAPI(TestCase):
         self.assertEqual(cur_weight, prev_weight)
         self.assertEqual(cur_buffer, prev_buffer)
 
-    def test_reparametrized_module(self):
+    def test_reparametrized_module_change_parametrization_original(self):
         module = MockModule()
         torch.nn.utils.parametrizations.spectral_norm(module.l1)
         self.assertTrue('l1.parametrizations.weight.original' in dict(module.named_parameters()))
+        orig_sn_weight = module.l1.weight.clone()
         x = torch.rand((1, 1))
+        # We substitute the parameter inside the parametrization
+        # the parametrization itself is not overwritten so it will be applied with a different
+        # value for the original tensor
         parameters = {'l1.parametrizations.weight.original': torch.nn.Parameter(torch.tensor([[1.0]])),
                       'l1.bias': torch.tensor([0.0]),
                       'buffer': torch.tensor([0.0])}
@@ -142,6 +152,5 @@ class TestStatelessFunctionalAPI(TestCase):
         self.assertTrue('l1.parametrizations.weight.original' in dict(module.named_parameters()))
+        self.assertEqual(orig_sn_weight, module.l1.weight)
 
 if __name__ == '__main__':
     run_tests()