Add assign kwarg to module.load_state_dict (#102212)
Fixes #64601 and #98906

Adds an `assign` argument to `load_state_dict` that loads params/buffers by assignment instead of doing `param.copy_(param_from_state_dict)`.

This is primarily intended to remove the need for the `.to_empty()` call in

```
with torch.device('meta'):
    m = SomeModule()
m.to_empty()
state_dict = torch.load('...pth')
m.load_state_dict(state_dict)
```

so that we can instead do

```
with torch.device('meta'):
    m = SomeModule()
state_dict = torch.load('...pth')
m.load_state_dict(state_dict, assign=True)
```

**An open problem with this PR when the model is initialized on the meta device: what happens to non-persistent buffers and to params whose keys are missing from the state dict?**

If `load_state_dict(state_dict, strict=False, assign=True)` is called and the `state_dict` is missing some keys, the corresponding params and the non-persistent buffers would still be on `meta` and would need to be manually initialized. However, we don't currently offer an API that would initialize these. One option would be to materialize them as empty tensors, but that might not be semantically correct.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/102212
Approved by: https://github.com/albanD
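To make the open question above concrete, here is a minimal sketch (not part of the PR; `SomeModule` is given a made-up non-persistent buffer for illustration, and a PyTorch build that includes this change is assumed): with `strict=False, assign=True`, anything without a matching key in the state dict is left on `meta` and still has to be materialized by hand.

```
# Sketch only: illustrates the open question above, not code from this PR.
# `SomeModule` here is a made-up module with one Linear layer and one
# non-persistent buffer.
import torch
import torch.nn as nn

class SomeModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(3, 3)
        # Non-persistent buffers never appear in the state_dict.
        self.register_buffer("running_scale", torch.ones(3), persistent=False)

# A real checkpoint would come from torch.load(...); build one in-process here.
state_dict = SomeModule().state_dict()
del state_dict['fc.bias']  # simulate a key missing from the checkpoint

with torch.device('meta'):
    m = SomeModule()

m.load_state_dict(state_dict, strict=False, assign=True)
print(m.fc.weight.device)      # cpu  -- assigned from the state_dict
print(m.fc.bias.device)        # meta -- key was missing, still uninitialized
print(m.running_scale.device)  # meta -- non-persistent, never in the state_dict
```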
Committed by: PyTorch MergeBot
Parent: 73be9842be
Commit: d1cecd9c32

test/test_nn.py: +117
@@ -2582,6 +2582,123 @@ tensor(..., device='meta', size=(1,), requires_grad=True)""")

        self.assertEqual(m.state_dict(), m2.state_dict())
        self.assertEqual(m.foo, m2.foo)

    def test_load_state_dict_assign_meta(self):
        class MyModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.fc1 = nn.Linear(3, 5)
                self.bn = nn.BatchNorm1d(5)

            def forward(self, input):
                return self.bn(self.fc1(input))

        net = MyModule()
        state_dict = net.state_dict(keep_vars=True)

        with torch.device('meta'):
            net_meta = MyModule()

        net_meta.load_state_dict(state_dict, assign=True)

        # Make sure parameters and persistent buffers were assigned
        net_meta_state_dict = net_meta.state_dict(keep_vars=True)
        for key in state_dict.keys():
            if isinstance(state_dict[key], torch.nn.Parameter):
                self.assertTrue(state_dict[key] is net_meta_state_dict[key])

        # Make sure that ordering of parameters and buffers is preserved
        net_named_parameters = net.named_parameters()
        net_named_buffers = net.named_buffers()
        net_meta_named_parameters = net_meta.named_parameters()
        net_meta_named_buffers = net_meta.named_buffers()

        for p1, p2 in zip(net_named_parameters, net_meta_named_parameters):
            n1, _ = p1
            n2, _ = p2
            self.assertEqual(n1, n2)

        for p1, p2 in zip(net_named_buffers, net_meta_named_buffers):
            n1, _ = p1
            n2, _ = p2
            self.assertEqual(n1, n2)

        # Make sure outputs are the same
        t = torch.randn(4, 3)
        out_net = net(t)
        out_net_meta = net_meta(t.clone())

        self.assertEqual(out_net, out_net_meta)

    def test_load_state_dict_assign_with_optimizer(self):
        class MyModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.fc1 = nn.Linear(3, 5)
                self.bn = nn.BatchNorm1d(5)

            def forward(self, input):
                return self.bn(self.fc1(input))

        net = MyModule()
        opt = torch.optim.Adam(net.parameters(), lr=1000)
        x = torch.randn(4, 3)
        num_iters = 3

        for i in range(num_iters):
            opt.zero_grad()
            out = net(x)
            out.sum().backward()
            opt.step()

        opt_state_dict = deepcopy(opt.state_dict())
        net_state_dict = deepcopy(net.state_dict())

        with torch.device('meta'):
            net_meta = MyModule()

        net_meta.load_state_dict(net_state_dict, assign=True)
        # must create optimizer only after loading state_dict when assign=True
        opt2 = torch.optim.Adam(net_meta.parameters(), lr=1000)
        opt2.load_state_dict(opt_state_dict)

        y = x.clone()
        for i in range(num_iters):
            opt.zero_grad()
            out = net(x)
            out.sum().backward()
            opt.step()

            opt2.zero_grad()
            out2 = net_meta(y)
            out2.sum().backward()
            opt2.step()

        self.assertEqual(opt.state_dict(), opt2.state_dict())
        self.assertEqual(net.state_dict(), net_meta.state_dict())

    def test_load_state_dict_assign_shape_stride(self):
        # Assigned tensor is allowed to have different properties than initial
        # tensor except for shape
        class MyModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.fc1 = nn.Linear(3, 5)
                self.bn = nn.BatchNorm1d(5)

            def forward(self, input):
                return self.bn(self.fc1(input))

        net = MyModule()
        state_dict = net.state_dict()
        # loading should be ok if stride is different
        state_dict['fc1.weight'] = torch.randn(3, 5).transpose(0, 1)
        net2 = MyModule()
        net2.load_state_dict(state_dict, strict=False, assign=True)

        state_dict['fc1.weight'] = torch.randn(2, 4)
        with self.assertRaisesRegex(RuntimeError, "size mismatch for fc1.weight: copying a param with shape"):
            net2.load_state_dict(state_dict, strict=False, assign=True)

    def test_extra_state_missing_set_extra_state(self):

        class MyModule(torch.nn.Module):
@@ -394,6 +394,7 @@ class _RemoteModule(nn.Module):

        self,
        state_dict: Mapping[str, Any],
        strict: bool = True,
        assign: bool = False,
    ):
        _raise_not_supported(self.load_state_dict.__name__)
@@ -1896,6 +1896,9 @@ class Module:

        For state dicts without metadata, :attr:`local_metadata` is empty.
        Subclasses can achieve class-specific backward compatible loading using
        the version number at `local_metadata.get("version", None)`.
        Additionally, :attr:`local_metadata` can also contain the key
        `assign_to_params_buffers` that indicates whether keys should be
        assigned their corresponding tensor in the state_dict.

        .. note::
            :attr:`state_dict` is not the same object as the input
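To illustrate the new `assign_to_params_buffers` metadata key described above, here is a hedged sketch of a subclass override (the `ScaledLinear` module and its legacy `gain` key are invented for this example, not taken from the PR): the override consults `local_metadata.get("assign_to_params_buffers", False)` to decide whether a renamed legacy entry should be assigned or copied in place.

```
# Hypothetical subclass; `ScaledLinear` and the legacy "gain" key are made up.
import torch
import torch.nn as nn

class ScaledLinear(nn.Linear):
    def __init__(self, in_features, out_features):
        super().__init__(in_features, out_features)
        self.register_buffer("scale", torch.ones(()))

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        # Backward compatibility: old checkpoints stored the buffer under "gain".
        old_key, new_key = prefix + "gain", prefix + "scale"
        if old_key in state_dict:
            value = state_dict.pop(old_key)
            if local_metadata.get("assign_to_params_buffers", False):
                # Assignment: the checkpoint tensor itself becomes the buffer.
                self.scale = value
            else:
                with torch.no_grad():
                    self.scale.copy_(value)
            # Re-insert under the new name so the default loader below does not
            # report "scale" as missing.
            state_dict[new_key] = self.scale
        # Default handling for the weight, bias and (renamed) buffer.
        super()._load_from_state_dict(state_dict, prefix, local_metadata, strict,
                                      missing_keys, unexpected_keys, error_msgs)
```

With an override along these lines, `load_state_dict(..., assign=True)` on an old checkpoint keeps the checkpoint's tensor for `scale` (including its dtype and device), while the default copy path writes it into the existing buffer in place.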
@@ -1926,6 +1929,7 @@ class Module:

        persistent_buffers = {k: v for k, v in self._buffers.items() if k not in self._non_persistent_buffers_set}
        local_name_params = itertools.chain(self._parameters.items(), persistent_buffers.items())
        local_state = {k: v for k, v in local_name_params if v is not None}
        assign_to_params_buffers = local_metadata.get("assign_to_params_buffers", False)

        for name, param in local_state.items():
            key = prefix + name
@@ -1954,7 +1958,15 @@

                    continue
                try:
                    with torch.no_grad():
                        if assign_to_params_buffers:
                            # Shape checks are already done above
                            if (isinstance(param, torch.nn.Parameter) and
                                    not isinstance(input_param, torch.nn.Parameter)):
                                setattr(self, name, torch.nn.Parameter(input_param))
                            else:
                                setattr(self, name, input_param)
                        else:
                            param.copy_(input_param)
                except Exception as ex:
                    error_msgs.append('While copying the parameter named "{}", '
                                      'whose dimensions in the model are {} and '
@@ -1982,18 +1994,29 @@

                        unexpected_keys.append(key)

    def load_state_dict(self, state_dict: Mapping[str, Any],
                        strict: bool = True, assign: bool = False):
        r"""Copies parameters and buffers from :attr:`state_dict` into
        this module and its descendants. If :attr:`strict` is ``True``, then
        the keys of :attr:`state_dict` must exactly match the keys returned
        by this module's :meth:`~torch.nn.Module.state_dict` function.

        .. warning::
            If :attr:`assign` is ``True`` the optimizer must be created after
            the call to :attr:`load_state_dict`.

        Args:
            state_dict (dict): a dict containing parameters and
                persistent buffers.
            strict (bool, optional): whether to strictly enforce that the keys
                in :attr:`state_dict` match the keys returned by this module's
                :meth:`~torch.nn.Module.state_dict` function. Default: ``True``
            assign (bool, optional): whether to assign items in the state
                dictionary to their corresponding keys in the module instead
                of copying them inplace into the module's current parameters and buffers.
                When ``False``, the properties of the tensors in the current
                module are preserved while when ``True``, the properties of the
                Tensors in the state dict are preserved.
                Default: ``False``

        Returns:
            ``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields:
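A minimal usage sketch of the ordering requirement in the warning above (the module, learning rate and checkpoint are placeholders): when loading with `assign=True`, the meta-initialized parameters are replaced rather than updated in place, so an optimizer created beforehand would keep references to the discarded meta tensors.

```
# Sketch of the warning above; the module and hyperparameters are placeholders.
import torch
import torch.nn as nn

class MyModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(3, 5)

    def forward(self, x):
        return self.fc(x)

# Stand-in for state_dict = torch.load('checkpoint.pth')
state_dict = MyModule().state_dict()

with torch.device('meta'):
    m = MyModule()

# Wrong order: opt = torch.optim.Adam(m.parameters(), lr=1e-3) here would
# capture the meta parameters, which the assignment below replaces.
m.load_state_dict(state_dict, assign=True)

# Correct order: create the optimizer only after load_state_dict(assign=True).
opt = torch.optim.Adam(m.parameters(), lr=1e-3)
```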
@@ -2021,6 +2044,8 @@ class Module:

        def load(module, local_state_dict, prefix=''):
            local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
            if assign:
                local_metadata['assign_to_params_buffers'] = assign
            module._load_from_state_dict(
                local_state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
            for name, child in module._modules.items():