mirror of
https://github.com/pytorch/pytorch.git
synced 2025-11-06 09:17:11 +08:00
fix serialization of nn.Parameter with dill (#10296)
Summary: Should resolve #9981. Pull Request resolved: https://github.com/pytorch/pytorch/pull/10296 Differential Revision: D9196353 Pulled By: soumith fbshipit-source-id: 109b6da42b7240cdbc7a0586745c735bce5e1279
This commit is contained in:
committed by
Facebook Github Bot
parent
1350f76b62
commit
4d28b65fb8
@ -101,6 +101,13 @@ def _rebuild_tensor_v2(storage, storage_offset, size, stride, requires_grad, bac
|
|||||||
return tensor
|
return tensor
|
||||||
|
|
||||||
|
|
||||||
|
def _rebuild_parameter(data, requires_grad, backward_hooks):
|
||||||
|
param = torch.nn.Parameter(data, requires_grad)
|
||||||
|
param._backward_hooks = backward_hooks
|
||||||
|
|
||||||
|
return param
|
||||||
|
|
||||||
|
|
||||||
def _import_dotted_name(name):
|
def _import_dotted_name(name):
|
||||||
components = name.split('.')
|
components = name.split('.')
|
||||||
obj = __import__(components[0])
|
obj = __import__(components[0])
|
||||||
|
|||||||
@ -18,6 +18,7 @@ class Parameter(torch.Tensor):
|
|||||||
requires_grad (bool, optional): if the parameter requires gradient. See
|
requires_grad (bool, optional): if the parameter requires gradient. See
|
||||||
:ref:`excluding-subgraphs` for more details. Default: `True`
|
:ref:`excluding-subgraphs` for more details. Default: `True`
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __new__(cls, data=None, requires_grad=True):
|
def __new__(cls, data=None, requires_grad=True):
|
||||||
if data is None:
|
if data is None:
|
||||||
data = torch.Tensor()
|
data = torch.Tensor()
|
||||||
@ -27,4 +28,7 @@ class Parameter(torch.Tensor):
|
|||||||
return 'Parameter containing:\n' + super(Parameter, self).__repr__()
|
return 'Parameter containing:\n' + super(Parameter, self).__repr__()
|
||||||
|
|
||||||
def __reduce_ex__(self, proto):
|
def __reduce_ex__(self, proto):
|
||||||
return Parameter, (super(Parameter, self), self.requires_grad)
|
return (
|
||||||
|
torch._utils._rebuild_parameter,
|
||||||
|
(self.data, self.requires_grad, self._backward_hooks)
|
||||||
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user