fix serialization of nn.Parameter with dill (#10296)

Summary:
Should resolve #9981.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/10296

Differential Revision: D9196353

Pulled By: soumith

fbshipit-source-id: 109b6da42b7240cdbc7a0586745c735bce5e1279
This commit is contained in:
Marcin Elantkowski
2018-09-01 23:53:32 -07:00
committed by Facebook Github Bot
parent 1350f76b62
commit 4d28b65fb8
2 changed files with 12 additions and 1 deletions

View File

@ -101,6 +101,13 @@ def _rebuild_tensor_v2(storage, storage_offset, size, stride, requires_grad, bac
return tensor return tensor
def _rebuild_parameter(data, requires_grad, backward_hooks):
param = torch.nn.Parameter(data, requires_grad)
param._backward_hooks = backward_hooks
return param
def _import_dotted_name(name): def _import_dotted_name(name):
components = name.split('.') components = name.split('.')
obj = __import__(components[0]) obj = __import__(components[0])

View File

@ -18,6 +18,7 @@ class Parameter(torch.Tensor):
requires_grad (bool, optional): if the parameter requires gradient. See requires_grad (bool, optional): if the parameter requires gradient. See
:ref:`excluding-subgraphs` for more details. Default: `True` :ref:`excluding-subgraphs` for more details. Default: `True`
""" """
def __new__(cls, data=None, requires_grad=True): def __new__(cls, data=None, requires_grad=True):
if data is None: if data is None:
data = torch.Tensor() data = torch.Tensor()
@ -27,4 +28,7 @@ class Parameter(torch.Tensor):
return 'Parameter containing:\n' + super(Parameter, self).__repr__() return 'Parameter containing:\n' + super(Parameter, self).__repr__()
def __reduce_ex__(self, proto): def __reduce_ex__(self, proto):
return Parameter, (super(Parameter, self), self.requires_grad) return (
torch._utils._rebuild_parameter,
(self.data, self.requires_grad, self._backward_hooks)
)