mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-10-20 21:14:14 +08:00 
			
		
		
		
	Add lightweight reparametrization for _stateless calls (#68969)
				
					
				
			Summary: https://github.com/pytorch/pytorch/issues/61447 introduced a mechanism for performing functional calls in a model using the reparametrization API. However, the overhead introduced in a single call was too large. I tried to address this by modifying the reparametrization code to support spare tensors, but the changes needed were too large due to type checking and several parts of the code expecting actual `nn.Module` objects so this option was not feasible. resnet50 and call functional with a parameters dict covering the 0, 25, 50, and 100% of the model total parameters. Used script: https://gist.github.com/emcastillo/f344a58638bd71d130c71c45f86f0c3a | % of parameters passed | CPU Time (us) | GPU Time (us) | |------------------------|---------------|---------------| | regular call | 5539 | 184909 | | 0 | 5561 | 184843 | | 25 | 11363 | 189236 | | 50 | 18716 | 195378 | | 75 | 22851 | 198641 | | 100 | 27441 | 202281 | This PR just swaps the `__getattr__` of the submodules to look into a dict holding only the parameters when called, greatly reducing the burden of having to instantiate custom modules and calling forward to just retrieve a tensor. The execution times now are as follows: | % of parameters passed | CPU Time (us) | GPU Time (us) | |------------------------|---------------|---------------| | regular call | 5939 | 187533 | | 0 | 5899 | 187570 | | 25 | 8541 | 188953 | | 50 | 10045 | 189826 | | 75 | 11049 | 190344 | | 100 | 11911 | 190800 | | functorch with 100% params | 14014 | 191727 Now we see that the CPU time overhead is greatly reduced and the GPU time barely increases due to the effective overlap. cc albanD zou3519 Pull Request resolved: https://github.com/pytorch/pytorch/pull/68969 Reviewed By: george-qi Differential Revision: D33836360 Pulled By: albanD fbshipit-source-id: 532561f64b18ca14c6ae2d77dcacb339397a589d (cherry picked from commit fd4b6bdfbff4cb3d1da47b7fd73f1edfe43ba65c)
This commit is contained in:
		
				
					committed by
					
						 PyTorch MergeBot
						PyTorch MergeBot
					
				
			
			
				
	
			
			
			
						parent
						
							9413c0cd3e
						
					
				
				
					commit
					fa38e93fe9
				
			| @ -58,9 +58,16 @@ class TestStatelessFunctionalAPI(TestCase): | ||||
|         jit_module = torch.jit.script(module) | ||||
|         with self.assertRaisesRegex( | ||||
|             RuntimeError, | ||||
|             r'delete methods or parameters' | ||||
|             r'used with Jitted modules' | ||||
|         ): | ||||
|             self._run_call_with_mock_module(jit_module) | ||||
|         x = torch.rand((1, 1)) | ||||
|         traced_module = torch.jit.trace(module, x) | ||||
|         with self.assertRaisesRegex( | ||||
|             RuntimeError, | ||||
|             r'used with Jitted modules' | ||||
|         ): | ||||
|             self._run_call_with_mock_module(traced_module) | ||||
|  | ||||
|     @unittest.skipIf(not TEST_MULTIGPU, 'multi-GPU not supported') | ||||
|     def test_functional_call_with_data_parallel(self): | ||||
| @ -127,12 +134,15 @@ class TestStatelessFunctionalAPI(TestCase): | ||||
|         self.assertEqual(cur_weight, prev_weight) | ||||
|         self.assertEqual(cur_buffer, prev_buffer) | ||||
|  | ||||
|     def test_reparametrized_module(self): | ||||
|     def test_reparametrized_module_change_parametrization_original(self): | ||||
|         module = MockModule() | ||||
|         torch.nn.utils.parametrizations.spectral_norm(module.l1) | ||||
|         self.assertTrue('l1.parametrizations.weight.original' in dict(module.named_parameters())) | ||||
|         orig_sn_weight = module.l1.weight.clone() | ||||
|         x = torch.rand((1, 1)) | ||||
|         # We substitute the parameter inside the parametrization | ||||
|         # the parametrization itself is not overwritten so it will be applied with a different | ||||
|         # value for the original tensor | ||||
|         parameters = {'l1.parametrizations.weight.original': torch.nn.Parameter(torch.tensor([[1.0]])), | ||||
|                       'l1.bias': torch.tensor([0.0]), | ||||
|                       'buffer': torch.tensor([0.0])} | ||||
| @ -142,6 +152,5 @@ class TestStatelessFunctionalAPI(TestCase): | ||||
|         self.assertTrue('l1.parametrizations.weight.original' in dict(module.named_parameters())) | ||||
|         self.assertEqual(orig_sn_weight, module.l1.weight) | ||||
|  | ||||
|  | ||||
| if __name__ == '__main__': | ||||
|     run_tests() | ||||
|  | ||||
		Reference in New Issue
	
	Block a user