mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-10-20 21:14:14 +08:00 
			
		
		
		
	Add assign kwarg to module.load_state_dict (#102212)
Fixes #64601 and #98906

Adds an `assign` argument to `load_state_dict` that loads params/buffers by assignment instead of doing `param.copy_(param_from_state_dict)`.

Primarily intended to remove the need for the `.to_empty()` in

```
with torch.device('meta'):
    m = SomeModule()
m.to_empty()
state_dict = torch.load('...pth')
m.load_state_dict(state_dict)
```

so we can instead do

```
with torch.device('meta'):
    m = SomeModule()
state_dict = torch.load('...pth')
m.load_state_dict(state_dict, assign=True)
```

**A problem with this PR for the case where the model is initialized on meta: what happens to nonpersistent buffers/params corresponding to keys missing from the state dict?**

What happens when `load_state_dict(state_dict, strict=False, assign=True)` is called and the state_dict is missing some keys? The params missing from the `state_dict`, as well as the nonpersistent buffers, would still be on `meta` and would need to be manually initialized. However, I don't think we offer an API that would initialize these.

One solution would be to make these empty tensors, but it might not be semantically correct...

Pull Request resolved: https://github.com/pytorch/pytorch/pull/102212
Approved by: https://github.com/albanD
This commit is contained in:
		
				
					committed by
					
						 PyTorch MergeBot
						PyTorch MergeBot
					
				
			
			
				
	
			
			
			
						parent
						
							73be9842be
						
					
				
				
					commit
					d1cecd9c32
				
			
							
								
								
									
										117
									
								
								test/test_nn.py
									
									
									
									
									
								
							
							
						
						
									
										117
									
								
								test/test_nn.py
									
									
									
									
									
								
							| @ -2582,6 +2582,123 @@ tensor(..., device='meta', size=(1,), requires_grad=True)""") | ||||
|             self.assertEqual(m.state_dict(), m2.state_dict()) | ||||
|             self.assertEqual(m.foo, m2.foo) | ||||
|  | ||||
|     def test_load_state_dict_assign_meta(self): | ||||
|         class MyModule(torch.nn.Module): | ||||
|             def __init__(self): | ||||
|                 super().__init__() | ||||
|                 self.fc1 = nn.Linear(3, 5) | ||||
|                 self.bn = nn.BatchNorm1d(5) | ||||
|  | ||||
|             def forward(self, input): | ||||
|                 return self.bn(self.fc1(input)) | ||||
|  | ||||
|         net = MyModule() | ||||
|         state_dict = net.state_dict(keep_vars=True) | ||||
|  | ||||
|         with torch.device('meta'): | ||||
|             net_meta = MyModule() | ||||
|  | ||||
|         net_meta.load_state_dict(state_dict, assign=True) | ||||
|  | ||||
|         # Make sure parameters and persistent buffers were assigned | ||||
|         net_meta_state_dict = net_meta.state_dict(keep_vars=True) | ||||
|         for key in state_dict.keys(): | ||||
|             if isinstance(state_dict[key], torch.nn.Parameter): | ||||
|                 self.assertTrue(state_dict[key] is net_meta_state_dict[key]) | ||||
|  | ||||
|         # Make sure that ordering of parameters and buffers is preserved | ||||
|         net_named_parameters = net.named_parameters() | ||||
|         net_named_buffers = net.named_buffers() | ||||
|         net_meta_named_parameters = net_meta.named_parameters() | ||||
|         net_meta_named_buffers = net_meta.named_buffers() | ||||
|  | ||||
|         for p1, p2 in zip(net_named_parameters, net_meta_named_parameters): | ||||
|             n1, _ = p1 | ||||
|             n2, _ = p2 | ||||
|             self.assertEqual(n1, n2) | ||||
|  | ||||
|         for p1, p2 in zip(net_named_buffers, net_meta_named_buffers): | ||||
|             n1, _ = p1 | ||||
|             n2, _ = p2 | ||||
|             self.assertEqual(n1, n2) | ||||
|  | ||||
|         # Make sure outputs are the same | ||||
|         t = torch.randn(4, 3) | ||||
|         out_net = net(t) | ||||
|         out_net_meta = net_meta(t.clone()) | ||||
|  | ||||
|         self.assertEqual(out_net, out_net_meta) | ||||
|  | ||||
|     def test_load_state_dict_assign_with_optimizer(self): | ||||
|         class MyModule(torch.nn.Module): | ||||
|             def __init__(self): | ||||
|                 super().__init__() | ||||
|                 self.fc1 = nn.Linear(3, 5) | ||||
|                 self.bn = nn.BatchNorm1d(5) | ||||
|  | ||||
|             def forward(self, input): | ||||
|                 return self.bn(self.fc1(input)) | ||||
|  | ||||
|         net = MyModule() | ||||
|         opt = torch.optim.Adam(net.parameters(), lr=1000) | ||||
|         x = torch.randn(4, 3) | ||||
|         num_iters = 3 | ||||
|  | ||||
|         for i in range(num_iters): | ||||
|             opt.zero_grad() | ||||
|             out = net(x) | ||||
|             out.sum().backward() | ||||
|             opt.step() | ||||
|  | ||||
|         opt_state_dict = deepcopy(opt.state_dict()) | ||||
|         net_state_dict = deepcopy(net.state_dict()) | ||||
|  | ||||
|         with torch.device('meta'): | ||||
|             net_meta = MyModule() | ||||
|  | ||||
|         net_meta.load_state_dict(net_state_dict, assign=True) | ||||
|         # must create optimizer only after loading state_dict when assign=True | ||||
|         opt2 = torch.optim.Adam(net_meta.parameters(), lr=1000) | ||||
|         opt2.load_state_dict(opt_state_dict) | ||||
|  | ||||
|         y = x.clone() | ||||
|         for i in range(num_iters): | ||||
|             opt.zero_grad() | ||||
|             out = net(x) | ||||
|             out.sum().backward() | ||||
|             opt.step() | ||||
|  | ||||
|             opt2.zero_grad() | ||||
|             out2 = net_meta(y) | ||||
|             out2.sum().backward() | ||||
|             opt2.step() | ||||
|  | ||||
|         self.assertEqual(opt.state_dict(), opt2.state_dict()) | ||||
|         self.assertEqual(net.state_dict(), net_meta.state_dict()) | ||||
|  | ||||
|     def test_load_state_dict_assign_shape_stride(self): | ||||
|         # Assigned tensor is allowed to have different properties than initial | ||||
|         # tensor except for shape | ||||
|         class MyModule(torch.nn.Module): | ||||
|             def __init__(self): | ||||
|                 super().__init__() | ||||
|                 self.fc1 = nn.Linear(3, 5) | ||||
|                 self.bn = nn.BatchNorm1d(5) | ||||
|  | ||||
|             def forward(self, input): | ||||
|                 return self.bn(self.fc1(input)) | ||||
|  | ||||
|         net = MyModule() | ||||
|         state_dict = net.state_dict() | ||||
|         # loading should be ok if stride is different | ||||
|         state_dict['fc1.weight'] = torch.randn(3, 5).transpose(0, 1) | ||||
|         net2 = MyModule() | ||||
|         net2.load_state_dict(state_dict, strict=False, assign=True) | ||||
|  | ||||
|         state_dict['fc1.weight'] = torch.randn(2, 4) | ||||
|         with self.assertRaisesRegex(RuntimeError, "size mismatch for fc1.weight: copying a param with shape"): | ||||
|             net2.load_state_dict(state_dict, strict=False, assign=True) | ||||
|  | ||||
|     def test_extra_state_missing_set_extra_state(self): | ||||
|  | ||||
|         class MyModule(torch.nn.Module): | ||||
|  | ||||
| @ -394,6 +394,7 @@ class _RemoteModule(nn.Module): | ||||
|         self, | ||||
|         state_dict: Mapping[str, Any], | ||||
|         strict: bool = True, | ||||
|         assign: bool = False, | ||||
|     ): | ||||
|         _raise_not_supported(self.load_state_dict.__name__) | ||||
|  | ||||
|  | ||||
| @ -1896,6 +1896,9 @@ class Module: | ||||
|         For state dicts without metadata, :attr:`local_metadata` is empty. | ||||
|         Subclasses can achieve class-specific backward compatible loading using | ||||
|         the version number at `local_metadata.get("version", None)`. | ||||
|         Additionally, :attr:`local_metadata` can also contain the key | ||||
|         `assign_to_params_buffers` that indicates whether keys should be | ||||
|         assigned their corresponding tensor in the state_dict. | ||||
|  | ||||
|         .. note:: | ||||
|             :attr:`state_dict` is not the same object as the input | ||||
| @ -1926,6 +1929,7 @@ class Module: | ||||
|         persistent_buffers = {k: v for k, v in self._buffers.items() if k not in self._non_persistent_buffers_set} | ||||
|         local_name_params = itertools.chain(self._parameters.items(), persistent_buffers.items()) | ||||
|         local_state = {k: v for k, v in local_name_params if v is not None} | ||||
|         assign_to_params_buffers = local_metadata.get("assign_to_params_buffers", False) | ||||
|  | ||||
|         for name, param in local_state.items(): | ||||
|             key = prefix + name | ||||
| @ -1954,7 +1958,15 @@ class Module: | ||||
|                     continue | ||||
|                 try: | ||||
|                     with torch.no_grad(): | ||||
|                         param.copy_(input_param) | ||||
|                         if assign_to_params_buffers: | ||||
|                             # Shape checks are already done above | ||||
|                             if (isinstance(param, torch.nn.Parameter) and | ||||
|                                     not isinstance(input_param, torch.nn.Parameter)): | ||||
|                                 setattr(self, name, torch.nn.Parameter(input_param)) | ||||
|                             else: | ||||
|                                 setattr(self, name, input_param) | ||||
|                         else: | ||||
|                             param.copy_(input_param) | ||||
|                 except Exception as ex: | ||||
|                     error_msgs.append('While copying the parameter named "{}", ' | ||||
|                                       'whose dimensions in the model are {} and ' | ||||
| @ -1982,18 +1994,29 @@ class Module: | ||||
|                         unexpected_keys.append(key) | ||||
|  | ||||
|     def load_state_dict(self, state_dict: Mapping[str, Any], | ||||
|                         strict: bool = True): | ||||
|                         strict: bool = True, assign: bool = False): | ||||
|         r"""Copies parameters and buffers from :attr:`state_dict` into | ||||
|         this module and its descendants. If :attr:`strict` is ``True``, then | ||||
|         the keys of :attr:`state_dict` must exactly match the keys returned | ||||
|         by this module's :meth:`~torch.nn.Module.state_dict` function. | ||||
|  | ||||
|         .. warning:: | ||||
|             If :attr:`assign` is ``True`` the optimizer must be created after | ||||
|             the call to :attr:`load_state_dict`. | ||||
|  | ||||
|         Args: | ||||
|             state_dict (dict): a dict containing parameters and | ||||
|                 persistent buffers. | ||||
|             strict (bool, optional): whether to strictly enforce that the keys | ||||
|                 in :attr:`state_dict` match the keys returned by this module's | ||||
|                 :meth:`~torch.nn.Module.state_dict` function. Default: ``True`` | ||||
|             assign (bool, optional): whether to assign items in the state | ||||
|                 dictionary to their corresponding keys in the module instead | ||||
|                 of copying them inplace into the module's current parameters and buffers. | ||||
|                 When ``False``, the properties of the tensors in the current | ||||
|                 module are preserved while when ``True``, the properties of the | ||||
|                 Tensors in the state dict are preserved. | ||||
|                 Default: ``False`` | ||||
|  | ||||
|         Returns: | ||||
|             ``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields: | ||||
| @ -2021,6 +2044,8 @@ class Module: | ||||
|  | ||||
|         def load(module, local_state_dict, prefix=''): | ||||
|             local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {}) | ||||
|             if assign: | ||||
|                 local_metadata['assign_to_params_buffers'] = assign | ||||
|             module._load_from_state_dict( | ||||
|                 local_state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs) | ||||
|             for name, child in module._modules.items(): | ||||
|  | ||||
		Reference in New Issue
	
	Block a user