Add assign kwarg to module.load_state_dict (#102212)

Fixes #64601 and #98906

Adds an `assign` argument to `load_state_dict` that loads params/buffers by assignment instead of doing `param.copy_(param_from_state_dict)`.
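
For illustration, a minimal sketch of the observable difference (the float64 state dict is just an example; this assumes the behavior described in the docstring change further down):

```
import torch
import torch.nn as nn

m = nn.Linear(3, 5)  # module params are float32 by default
# state dict whose tensors have different properties (here dtype) than the module's
sd = {'weight': torch.randn(5, 3, dtype=torch.float64),
      'bias': torch.randn(5, dtype=torch.float64)}

m.load_state_dict(sd)               # copy_: the module's properties win
print(m.weight.dtype)               # torch.float32

m.load_state_dict(sd, assign=True)  # assignment: the state dict's properties win
print(m.weight.dtype)               # torch.float64
```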

Primarily intended to remove the need for the `.to_empty()` call in

```
with torch.device('meta'):
    m = SomeModule()
m.to_empty(device='cpu')  # to_empty() needs a target device to materialize storage on
state_dict = torch.load('...pth')
m.load_state_dict(state_dict)
```

so we can instead do

```
with torch.device('meta'):
    m = SomeModule()
state_dict = torch.load('...pth')
m.load_state_dict(state_dict, assign=True)
```

**An open problem with this PR, for the case where the model is initialized on the meta device: what happens to non-persistent buffers and to params whose keys are missing from the state dict?**
If `load_state_dict(state_dict, strict=False, assign=True)` is called and the `state_dict` is missing some keys, the corresponding params, as well as any non-persistent buffers, would still be on `meta` and would need to be initialized manually. However, I don't think we currently offer an API that would initialize these.

One solution would be to materialize these as empty tensors, but that might not be semantically correct...
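
To make the concern concrete, a minimal sketch (the `SomeModule` with a non-persistent buffer below is hypothetical, not from this PR):

```
import torch
import torch.nn as nn

class SomeModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(3, 5)
        # non-persistent buffer: never serialized into the state dict
        self.register_buffer('running_stat', torch.zeros(5), persistent=False)

with torch.device('meta'):
    m = SomeModule()

state_dict = {'fc.weight': torch.randn(5, 3)}  # 'fc.bias' is deliberately missing
m.load_state_dict(state_dict, strict=False, assign=True)

print(m.fc.weight.device)     # cpu  (assigned from the state dict)
print(m.fc.bias.device)       # meta (key missing, never materialized)
print(m.running_stat.device)  # meta (non-persistent, never in a state dict)
```

The `meta` tensors above cannot be used until they are somehow re-initialized, which is exactly the gap described here.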

Pull Request resolved: https://github.com/pytorch/pytorch/pull/102212
Approved by: https://github.com/albanD
Authored by Mikayla Gawarecki on 2023-06-15 15:26:29 +00:00; committed by PyTorch MergeBot
parent 73be9842be
commit d1cecd9c32
3 changed files with 145 additions and 2 deletions


@@ -2582,6 +2582,123 @@ tensor(..., device='meta', size=(1,), requires_grad=True)""")
         self.assertEqual(m.state_dict(), m2.state_dict())
         self.assertEqual(m.foo, m2.foo)
 
+    def test_load_state_dict_assign_meta(self):
+        class MyModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.fc1 = nn.Linear(3, 5)
+                self.bn = nn.BatchNorm1d(5)
+
+            def forward(self, input):
+                return self.bn(self.fc1(input))
+
+        net = MyModule()
+        state_dict = net.state_dict(keep_vars=True)
+
+        with torch.device('meta'):
+            net_meta = MyModule()
+        net_meta.load_state_dict(state_dict, assign=True)
+
+        # Make sure parameters and persistent buffers were assigned
+        net_meta_state_dict = net_meta.state_dict(keep_vars=True)
+        for key in state_dict.keys():
+            if isinstance(state_dict[key], torch.nn.Parameter):
+                self.assertTrue(state_dict[key] is net_meta_state_dict[key])
+
+        # Make sure that ordering of parameters and buffers is preserved
+        net_named_parameters = net.named_parameters()
+        net_named_buffers = net.named_buffers()
+
+        net_meta_named_parameters = net_meta.named_parameters()
+        net_meta_named_buffers = net_meta.named_buffers()
+
+        for p1, p2 in zip(net_named_parameters, net_meta_named_parameters):
+            n1, _ = p1
+            n2, _ = p2
+            self.assertEqual(n1, n2)
+
+        for p1, p2 in zip(net_named_buffers, net_meta_named_buffers):
+            n1, _ = p1
+            n2, _ = p2
+            self.assertEqual(n1, n2)
+
+        # Make sure outputs are the same
+        t = torch.randn(4, 3)
+        out_net = net(t)
+        out_net_meta = net_meta(t.clone())
+
+        self.assertEqual(out_net, out_net_meta)
+
+    def test_load_state_dict_assign_with_optimizer(self):
+        class MyModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.fc1 = nn.Linear(3, 5)
+                self.bn = nn.BatchNorm1d(5)
+
+            def forward(self, input):
+                return self.bn(self.fc1(input))
+
+        net = MyModule()
+        opt = torch.optim.Adam(net.parameters(), lr=1000)
+        x = torch.randn(4, 3)
+        num_iters = 3
+
+        for i in range(num_iters):
+            opt.zero_grad()
+            out = net(x)
+            out.sum().backward()
+            opt.step()
+
+        opt_state_dict = deepcopy(opt.state_dict())
+        net_state_dict = deepcopy(net.state_dict())
+
+        with torch.device('meta'):
+            net_meta = MyModule()
+        net_meta.load_state_dict(net_state_dict, assign=True)
+
+        # must create optimizer only after loading state_dict when assign=True
+        opt2 = torch.optim.Adam(net_meta.parameters(), lr=1000)
+        opt2.load_state_dict(opt_state_dict)
+
+        y = x.clone()
+        for i in range(num_iters):
+            opt.zero_grad()
+            out = net(x)
+            out.sum().backward()
+            opt.step()
+
+            opt2.zero_grad()
+            out2 = net_meta(y)
+            out2.sum().backward()
+            opt2.step()
+
+        self.assertEqual(opt.state_dict(), opt2.state_dict())
+        self.assertEqual(net.state_dict(), net_meta.state_dict())
+
+    def test_load_state_dict_assign_shape_stride(self):
+        # Assigned tensor is allowed to have different properties than initial
+        # tensor except for shape
+        class MyModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.fc1 = nn.Linear(3, 5)
+                self.bn = nn.BatchNorm1d(5)
+
+            def forward(self, input):
+                return self.bn(self.fc1(input))
+
+        net = MyModule()
+        state_dict = net.state_dict()
+        # loading should be ok if stride is different
+        state_dict['fc1.weight'] = torch.randn(3, 5).transpose(0, 1)
+        net2 = MyModule()
+        net2.load_state_dict(state_dict, strict=False, assign=True)
+
+        state_dict['fc1.weight'] = torch.randn(2, 4)
+        with self.assertRaisesRegex(RuntimeError, "size mismatch for fc1.weight: copying a param with shape"):
+            net2.load_state_dict(state_dict, strict=False, assign=True)
+
     def test_extra_state_missing_set_extra_state(self):
         class MyModule(torch.nn.Module):


@@ -394,6 +394,7 @@ class _RemoteModule(nn.Module):
         self,
         state_dict: Mapping[str, Any],
         strict: bool = True,
+        assign: bool = False,
     ):
         _raise_not_supported(self.load_state_dict.__name__)


@@ -1896,6 +1896,9 @@ class Module:
         For state dicts without metadata, :attr:`local_metadata` is empty.
         Subclasses can achieve class-specific backward compatible loading using
         the version number at `local_metadata.get("version", None)`.
+        Additionally, :attr:`local_metadata` can also contain the key
+        `assign_to_params_buffers` that indicates whether keys should be
+        assigned their corresponding tensor in the state_dict.
 
         .. note::
             :attr:`state_dict` is not the same object as the input
@@ -1926,6 +1929,7 @@ class Module:
         persistent_buffers = {k: v for k, v in self._buffers.items() if k not in self._non_persistent_buffers_set}
         local_name_params = itertools.chain(self._parameters.items(), persistent_buffers.items())
         local_state = {k: v for k, v in local_name_params if v is not None}
+        assign_to_params_buffers = local_metadata.get("assign_to_params_buffers", False)
 
         for name, param in local_state.items():
             key = prefix + name
@@ -1954,7 +1958,15 @@ class Module:
                     continue
                 try:
                     with torch.no_grad():
-                        param.copy_(input_param)
+                        if assign_to_params_buffers:
+                            # Shape checks are already done above
+                            if (isinstance(param, torch.nn.Parameter) and
+                                    not isinstance(input_param, torch.nn.Parameter)):
+                                setattr(self, name, torch.nn.Parameter(input_param))
+                            else:
+                                setattr(self, name, input_param)
+                        else:
+                            param.copy_(input_param)
                 except Exception as ex:
                     error_msgs.append('While copying the parameter named "{}", '
                                       'whose dimensions in the model are {} and '
@@ -1982,18 +1994,29 @@ class Module:
                         unexpected_keys.append(key)
 
     def load_state_dict(self, state_dict: Mapping[str, Any],
-                        strict: bool = True):
+                        strict: bool = True, assign: bool = False):
         r"""Copies parameters and buffers from :attr:`state_dict` into
         this module and its descendants. If :attr:`strict` is ``True``, then
         the keys of :attr:`state_dict` must exactly match the keys returned
         by this module's :meth:`~torch.nn.Module.state_dict` function.
 
+        .. warning::
+            If :attr:`assign` is ``True`` the optimizer must be created after
+            the call to :attr:`load_state_dict`.
+
         Args:
             state_dict (dict): a dict containing parameters and
                 persistent buffers.
             strict (bool, optional): whether to strictly enforce that the keys
                 in :attr:`state_dict` match the keys returned by this module's
                 :meth:`~torch.nn.Module.state_dict` function. Default: ``True``
+            assign (bool, optional): whether to assign items in the state
+                dictionary to their corresponding keys in the module instead
+                of copying them inplace into the module's current parameters and buffers.
+                When ``False``, the properties of the tensors in the current
+                module are preserved while when ``True``, the properties of the
+                Tensors in the state dict are preserved.
+                Default: ``False``
 
         Returns:
             ``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields:
@@ -2021,6 +2044,8 @@ class Module:
         def load(module, local_state_dict, prefix=''):
             local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
+            if assign:
+                local_metadata['assign_to_params_buffers'] = assign
             module._load_from_state_dict(
                 local_state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
             for name, child in module._modules.items():