from collections import OrderedDict, namedtuple
import functools
import itertools

import torch
from ..backends.thnn import backend as thnn_backend
from ..parameter import Parameter
import torch.utils.hooks as hooks


_IncompatibleKeys = namedtuple('IncompatibleKeys', ['missing_keys', 'unexpected_keys'])


def _addindent(s_, numSpaces):
    s = s_.split('\n')
    # don't do anything for single-line stuff
    if len(s) == 1:
        return s_
    first = s.pop(0)
    s = [(numSpaces * ' ') + line for line in s]
    s = '\n'.join(s)
    s = first + '\n' + s
    return s


class Module(object):
    r"""Base class for all neural network modules.

    Your models should also subclass this class.

    Modules can also contain other Modules, allowing them to be nested in
    a tree structure. You can assign the submodules as regular attributes::

        import torch.nn as nn
        import torch.nn.functional as F

        class Model(nn.Module):
            def __init__(self):
                super(Model, self).__init__()
                self.conv1 = nn.Conv2d(1, 20, 5)
                self.conv2 = nn.Conv2d(20, 20, 5)

            def forward(self, x):
                x = F.relu(self.conv1(x))
                return F.relu(self.conv2(x))

    Submodules assigned in this way will be registered, and will have their
    parameters converted too when you call :meth:`to`, etc.
    """

    dump_patches = False

    r"""This allows better BC support for :meth:`load_state_dict`. In
    :meth:`state_dict`, the version number will be saved in the attribute
    `_metadata` of the returned state dict, and thus pickled. `_metadata` is a
    dictionary with keys that follow the naming convention of the state dict. See
    ``_load_from_state_dict`` on how to use this information in loading.

    If new parameters/buffers are added/removed from a module, this number shall
    be bumped, and the module's `_load_from_state_dict` method can compare the
    version number and make appropriate changes if the state dict is from before
    the change."""
    _version = 1

    def __init__(self):
        torch._C._log_api_usage_once("python.nn_module")
        self._backend = thnn_backend
        self._parameters = OrderedDict()
        self._buffers = OrderedDict()
        self._backward_hooks = OrderedDict()
        self._forward_hooks = OrderedDict()
        self._forward_pre_hooks = OrderedDict()
        self._state_dict_hooks = OrderedDict()
        self._load_state_dict_pre_hooks = OrderedDict()
        self._modules = OrderedDict()
        self.training = True

    def forward(self, *input):
        r"""Defines the computation performed at every call.

        Should be overridden by all subclasses.

        .. note::
            Although the recipe for the forward pass needs to be defined within
            this function, one should call the :class:`Module` instance afterwards
            instead of this since the former takes care of running the
            registered hooks while the latter silently ignores them.
        """
        raise NotImplementedError

    def register_buffer(self, name, tensor):
        r"""Adds a persistent buffer to the module.

        This is typically used to register a buffer that should not be
        considered a model parameter. For example, BatchNorm's ``running_mean``
        is not a parameter, but is part of the persistent state.

        Buffers can be accessed as attributes using given names.

        Args:
            name (string): name of the buffer. The buffer can be accessed
                from this module using the given name
            tensor (Tensor): buffer to be registered.

        Example::

            >>> self.register_buffer('running_mean', torch.zeros(num_features))

        """
        if '_buffers' not in self.__dict__:
            raise AttributeError(
                "cannot assign buffer before Module.__init__() call")
        elif not isinstance(name, torch._six.string_classes):
            raise TypeError("buffer name should be a string. "
                            "Got {}".format(torch.typename(name)))
        elif '.' in name:
            raise KeyError("buffer name can't contain \".\"")
        elif name == '':
            raise KeyError("buffer name can't be empty string \"\"")
        elif hasattr(self, name) and name not in self._buffers:
            raise KeyError("attribute '{}' already exists".format(name))
        elif tensor is not None and not isinstance(tensor, torch.Tensor):
            raise TypeError("cannot assign '{}' object to buffer '{}' "
                            "(torch Tensor or None required)"
                            .format(torch.typename(tensor), name))
        else:
            self._buffers[name] = tensor
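
    # Usage sketch (illustrative, assuming ``torch`` and ``torch.nn as nn`` are
    # imported): a registered buffer is saved in ``state_dict`` and moved by
    # ``.to()`` / ``.cuda()``, but it is not returned by ``parameters()`` and
    # receives no gradient. The module name ``RunningStats`` is hypothetical.
    #
    #     class RunningStats(nn.Module):
    #         def __init__(self, num_features):
    #             super(RunningStats, self).__init__()
    #             self.register_buffer('running_mean', torch.zeros(num_features))
    #
    #     stats = RunningStats(4)
    #     assert 'running_mean' in stats.state_dict()
    #     assert len(list(stats.parameters())) == 0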

    def register_parameter(self, name, param):
        r"""Adds a parameter to the module.

        The parameter can be accessed as an attribute using the given name.

        Args:
            name (string): name of the parameter. The parameter can be accessed
                from this module using the given name
            param (Parameter): parameter to be added to the module.
        """
        if '_parameters' not in self.__dict__:
            raise AttributeError(
                "cannot assign parameter before Module.__init__() call")

        elif not isinstance(name, torch._six.string_classes):
            raise TypeError("parameter name should be a string. "
                            "Got {}".format(torch.typename(name)))
        elif '.' in name:
            raise KeyError("parameter name can't contain \".\"")
        elif name == '':
            raise KeyError("parameter name can't be empty string \"\"")
        elif hasattr(self, name) and name not in self._parameters:
            raise KeyError("attribute '{}' already exists".format(name))

        if param is None:
            self._parameters[name] = None
        elif not isinstance(param, Parameter):
            raise TypeError("cannot assign '{}' object to parameter '{}' "
                            "(torch.nn.Parameter or None required)"
                            .format(torch.typename(param), name))
        elif param.grad_fn:
            raise ValueError(
                "Cannot assign non-leaf Tensor to parameter '{0}'. Model "
                "parameters must be created explicitly. To express '{0}' "
                "as a function of another Tensor, compute the value in "
                "the forward() method.".format(name))
        else:
            self._parameters[name] = param
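
    # Usage sketch (illustrative, assuming ``torch`` and ``torch.nn as nn`` are
    # imported): registering a parameter explicitly is equivalent to assigning a
    # ``Parameter`` attribute; both make it visible to ``named_parameters()`` and
    # the optimizer. The module name ``Scale`` is hypothetical.
    #
    #     class Scale(nn.Module):
    #         def __init__(self):
    #             super(Scale, self).__init__()
    #             self.register_parameter('weight', nn.Parameter(torch.ones(1)))
    #
    #         def forward(self, x):
    #             return self.weight * x
    #
    #     assert ['weight'] == [name for name, _ in Scale().named_parameters()]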

    def add_module(self, name, module):
        r"""Adds a child module to the current module.

        The module can be accessed as an attribute using the given name.

        Args:
            name (string): name of the child module. The child module can be
                accessed from this module using the given name
            module (Module): child module to be added to the module.
        """
        if not isinstance(module, Module) and module is not None:
            raise TypeError("{} is not a Module subclass".format(
                torch.typename(module)))
        elif not isinstance(name, torch._six.string_classes):
            raise TypeError("module name should be a string. Got {}".format(
                torch.typename(name)))
        elif hasattr(self, name) and name not in self._modules:
            raise KeyError("attribute '{}' already exists".format(name))
        elif '.' in name:
            raise KeyError("module name can't contain \".\"")
        elif name == '':
            raise KeyError("module name can't be empty string \"\"")
        self._modules[name] = module
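
    # Usage sketch (illustrative, assuming ``torch.nn as nn`` is imported):
    # ``add_module`` is useful when child names are built dynamically; the
    # children become attributes and are picked up by ``named_children()``.
    # The module name ``Stack`` is hypothetical.
    #
    #     class Stack(nn.Module):
    #         def __init__(self, depth):
    #             super(Stack, self).__init__()
    #             for i in range(depth):
    #                 self.add_module('layer{}'.format(i), nn.Linear(8, 8))
    #
    #     assert len(list(Stack(3).named_children())) == 3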

    def _apply(self, fn):
        for module in self.children():
            module._apply(fn)

        for param in self._parameters.values():
            if param is not None:
                # Tensors stored in modules are graph leaves, and we don't
                # want to create copy nodes, so we have to unpack the data.
                param.data = fn(param.data)
                if param._grad is not None:
                    param._grad.data = fn(param._grad.data)

        for key, buf in self._buffers.items():
            if buf is not None:
                self._buffers[key] = fn(buf)

        return self

    def apply(self, fn):
        r"""Applies ``fn`` recursively to every submodule (as returned by ``.children()``)
        as well as self. Typical use includes initializing the parameters of a model
        (see also :ref:`torch-nn-init`).

        Args:
            fn (:class:`Module` -> None): function to be applied to each submodule

        Returns:
            Module: self

        Example::

            >>> def init_weights(m):
            >>>     print(m)
            >>>     if type(m) == nn.Linear:
            >>>         m.weight.data.fill_(1.0)
            >>>         print(m.weight)
            >>> net = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))
            >>> net.apply(init_weights)
            Linear(in_features=2, out_features=2, bias=True)
            Parameter containing:
            tensor([[ 1.,  1.],
                    [ 1.,  1.]])
            Linear(in_features=2, out_features=2, bias=True)
            Parameter containing:
            tensor([[ 1.,  1.],
                    [ 1.,  1.]])
            Sequential(
              (0): Linear(in_features=2, out_features=2, bias=True)
              (1): Linear(in_features=2, out_features=2, bias=True)
            )
            Sequential(
              (0): Linear(in_features=2, out_features=2, bias=True)
              (1): Linear(in_features=2, out_features=2, bias=True)
            )
        """
        for module in self.children():
            module.apply(fn)
        fn(self)
        return self

    def cuda(self, device=None):
        r"""Moves all model parameters and buffers to the GPU.

        This also makes associated parameters and buffers different objects, so
        it should be called before constructing the optimizer if the module will
        live on the GPU while being optimized.

        Arguments:
            device (int, optional): if specified, all parameters will be
                copied to that device

        Returns:
            Module: self
        """
        return self._apply(lambda t: t.cuda(device))

    def cpu(self):
        r"""Moves all model parameters and buffers to the CPU.

        Returns:
            Module: self
        """
        return self._apply(lambda t: t.cpu())

    def type(self, dst_type):
        r"""Casts all parameters and buffers to :attr:`dst_type`.

        Arguments:
            dst_type (type or string): the desired type

        Returns:
            Module: self
        """
        return self._apply(lambda t: t.type(dst_type))

    def float(self):
        r"""Casts all floating point parameters and buffers to ``float`` datatype.

        Returns:
            Module: self
        """
        return self._apply(lambda t: t.float() if t.is_floating_point() else t)

    def double(self):
        r"""Casts all floating point parameters and buffers to ``double`` datatype.

        Returns:
            Module: self
        """
        return self._apply(lambda t: t.double() if t.is_floating_point() else t)

    def half(self):
        r"""Casts all floating point parameters and buffers to ``half`` datatype.

        Returns:
            Module: self
        """
        return self._apply(lambda t: t.half() if t.is_floating_point() else t)

    def to(self, *args, **kwargs):
        r"""Moves and/or casts the parameters and buffers.

        This can be called as

        .. function:: to(device=None, dtype=None, non_blocking=False)

        .. function:: to(dtype, non_blocking=False)

        .. function:: to(tensor, non_blocking=False)

        Its signature is similar to :meth:`torch.Tensor.to`, but only accepts
        floating point :attr:`dtype` s. In addition, this method will
        only cast the floating point parameters and buffers to :attr:`dtype`
        (if given). The integral parameters and buffers will be moved to
        :attr:`device`, if that is given, but with dtypes unchanged. When
        :attr:`non_blocking` is set, it tries to convert/move asynchronously
        with respect to the host if possible, e.g., moving CPU Tensors with
        pinned memory to CUDA devices.

        See below for examples.

        .. note::
            This method modifies the module in-place.

        Args:
            device (:class:`torch.device`): the desired device of the parameters
                and buffers in this module
            dtype (:class:`torch.dtype`): the desired floating point type of
                the floating point parameters and buffers in this module
            tensor (torch.Tensor): Tensor whose dtype and device are the desired
                dtype and device for all parameters and buffers in this module

        Returns:
            Module: self

        Example::

            >>> linear = nn.Linear(2, 2)
            >>> linear.weight
            Parameter containing:
            tensor([[ 0.1913, -0.3420],
                    [-0.5113, -0.2325]])
            >>> linear.to(torch.double)
            Linear(in_features=2, out_features=2, bias=True)
            >>> linear.weight
            Parameter containing:
            tensor([[ 0.1913, -0.3420],
                    [-0.5113, -0.2325]], dtype=torch.float64)
            >>> gpu1 = torch.device("cuda:1")
            >>> linear.to(gpu1, dtype=torch.half, non_blocking=True)
            Linear(in_features=2, out_features=2, bias=True)
            >>> linear.weight
            Parameter containing:
            tensor([[ 0.1914, -0.3420],
                    [-0.5112, -0.2324]], dtype=torch.float16, device='cuda:1')
            >>> cpu = torch.device("cpu")
            >>> linear.to(cpu)
            Linear(in_features=2, out_features=2, bias=True)
            >>> linear.weight
            Parameter containing:
            tensor([[ 0.1914, -0.3420],
                    [-0.5112, -0.2324]], dtype=torch.float16)

        """
        device, dtype, non_blocking = torch._C._nn._parse_to(*args, **kwargs)

        if dtype is not None:
            if not dtype.is_floating_point:
                raise TypeError('nn.Module.to only accepts floating point '
                                'dtypes, but got desired dtype={}'.format(dtype))

        def convert(t):
            return t.to(device, dtype if t.is_floating_point() else None, non_blocking)

        return self._apply(convert)

    def register_backward_hook(self, hook):
        r"""Registers a backward hook on the module.

        The hook will be called every time the gradients with respect to module
        inputs are computed. The hook should have the following signature::

            hook(module, grad_input, grad_output) -> Tensor or None

        The :attr:`grad_input` and :attr:`grad_output` may be tuples if the
        module has multiple inputs or outputs. The hook should not modify its
        arguments, but it can optionally return a new gradient with respect to
        input that will be used in place of :attr:`grad_input` in subsequent
        computations.

        Returns:
            :class:`torch.utils.hooks.RemovableHandle`:
                a handle that can be used to remove the added hook by calling
                ``handle.remove()``

        .. warning ::

            The current implementation will not have the presented behavior
            for a complex :class:`Module` that performs many operations.
            In some failure cases, :attr:`grad_input` and :attr:`grad_output` will only
            contain the gradients for a subset of the inputs and outputs.
            For such a :class:`Module`, you should use :func:`torch.Tensor.register_hook`
            directly on a specific input or output to get the required gradients.

        """
        handle = hooks.RemovableHandle(self._backward_hooks)
        self._backward_hooks[handle.id] = hook
        return handle
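
    # Usage sketch (illustrative, assuming ``torch`` and ``torch.nn as nn`` are
    # imported): a backward hook that clips ``grad_input`` by returning a
    # replacement tuple; returning ``None`` would keep the gradients unchanged.
    # The helper name ``clip_grad_input`` is hypothetical.
    #
    #     def clip_grad_input(module, grad_input, grad_output):
    #         return tuple(g.clamp(-1.0, 1.0) if g is not None else g
    #                      for g in grad_input)
    #
    #     linear = nn.Linear(4, 2)
    #     handle = linear.register_backward_hook(clip_grad_input)
    #     linear(torch.randn(3, 4)).sum().backward()
    #     handle.remove()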

    def register_forward_pre_hook(self, hook):
        r"""Registers a forward pre-hook on the module.

        The hook will be called every time before :func:`forward` is invoked.
        It should have the following signature::

            hook(module, input) -> None

        The hook should not modify the input.

        Returns:
            :class:`torch.utils.hooks.RemovableHandle`:
                a handle that can be used to remove the added hook by calling
                ``handle.remove()``
        """
        handle = hooks.RemovableHandle(self._forward_pre_hooks)
        self._forward_pre_hooks[handle.id] = hook
        return handle
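
    # Usage sketch (illustrative, assuming ``torch`` and ``torch.nn as nn`` are
    # imported): a pre-hook that logs the input shapes right before ``forward``
    # runs. The function name ``log_input_shapes`` is hypothetical.
    #
    #     def log_input_shapes(module, input):
    #         print(module.__class__.__name__, [tuple(t.shape) for t in input])
    #
    #     linear = nn.Linear(4, 2)
    #     handle = linear.register_forward_pre_hook(log_input_shapes)
    #     linear(torch.randn(3, 4))
    #     handle.remove()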

    def register_forward_hook(self, hook):
        r"""Registers a forward hook on the module.

        The hook will be called every time after :func:`forward` has computed an output.
        It should have the following signature::

            hook(module, input, output) -> None

        The hook should not modify the input or output.

        Returns:
            :class:`torch.utils.hooks.RemovableHandle`:
                a handle that can be used to remove the added hook by calling
                ``handle.remove()``
        """
        handle = hooks.RemovableHandle(self._forward_hooks)
        self._forward_hooks[handle.id] = hook
        return handle
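
    # Usage sketch (illustrative, assuming ``torch`` and ``torch.nn as nn`` are
    # imported): forward hooks are a common way to capture intermediate
    # activations without editing ``forward``. The ``activations`` dict and
    # ``save_activation`` name are hypothetical.
    #
    #     activations = {}
    #
    #     def save_activation(module, input, output):
    #         activations['conv1'] = output.detach()
    #
    #     model = nn.Sequential(nn.Conv2d(1, 20, 5), nn.ReLU())
    #     handle = model[0].register_forward_hook(save_activation)
    #     model(torch.randn(1, 1, 28, 28))
    #     handle.remove()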

    def _tracing_name(self, tracing_state):
        if not tracing_state._traced_module_stack:
            return None
        module = tracing_state._traced_module_stack[-1]
        for name, child in module.named_children():
            if child is self:
                return name
        return None

    def _slow_forward(self, *input, **kwargs):
        tracing_state = torch._C._get_tracing_state()
        if not tracing_state:
            return self.forward(*input, **kwargs)
        if not hasattr(tracing_state, '_traced_module_stack'):
            tracing_state._traced_module_stack = []
        name = self._tracing_name(tracing_state)
        if name:
            tracing_state.push_scope('%s[%s]' % (self._get_name(), name))
        else:
            tracing_state.push_scope(self._get_name())
        tracing_state._traced_module_stack.append(self)
        try:
            result = self.forward(*input, **kwargs)
        finally:
            tracing_state.pop_scope()
            tracing_state._traced_module_stack.pop()
        return result

    def __call__(self, *input, **kwargs):
        for hook in self._forward_pre_hooks.values():
            hook(self, input)
        if torch._C._get_tracing_state():
            result = self._slow_forward(*input, **kwargs)
        else:
            result = self.forward(*input, **kwargs)
        for hook in self._forward_hooks.values():
            hook_result = hook(self, input, result)
            if hook_result is not None:
                raise RuntimeError(
                    "forward hooks should never return any values, but '{}' "
                    "didn't return None".format(hook))
        if len(self._backward_hooks) > 0:
            var = result
            while not isinstance(var, torch.Tensor):
                if isinstance(var, dict):
                    var = next((v for v in var.values() if isinstance(v, torch.Tensor)))
                else:
                    var = var[0]
            grad_fn = var.grad_fn
            if grad_fn is not None:
                for hook in self._backward_hooks.values():
                    wrapper = functools.partial(hook, self)
                    functools.update_wrapper(wrapper, hook)
                    grad_fn.register_hook(wrapper)
        return result

    def __setstate__(self, state):
        self.__dict__.update(state)
        # Support loading old checkpoints that don't have the following attrs:
        if '_forward_pre_hooks' not in self.__dict__:
            self._forward_pre_hooks = OrderedDict()
        if '_state_dict_hooks' not in self.__dict__:
            self._state_dict_hooks = OrderedDict()
        if '_load_state_dict_pre_hooks' not in self.__dict__:
            self._load_state_dict_pre_hooks = OrderedDict()

    def __getattr__(self, name):
        if '_parameters' in self.__dict__:
            _parameters = self.__dict__['_parameters']
            if name in _parameters:
                return _parameters[name]
        if '_buffers' in self.__dict__:
            _buffers = self.__dict__['_buffers']
            if name in _buffers:
                return _buffers[name]
        if '_modules' in self.__dict__:
            modules = self.__dict__['_modules']
            if name in modules:
                return modules[name]
        raise AttributeError("'{}' object has no attribute '{}'".format(
            type(self).__name__, name))

    def __setattr__(self, name, value):
        def remove_from(*dicts):
            for d in dicts:
                if name in d:
                    del d[name]

        params = self.__dict__.get('_parameters')
        if isinstance(value, Parameter):
            if params is None:
                raise AttributeError(
                    "cannot assign parameters before Module.__init__() call")
            remove_from(self.__dict__, self._buffers, self._modules)
            self.register_parameter(name, value)
        elif params is not None and name in params:
            if value is not None:
                raise TypeError("cannot assign '{}' as parameter '{}' "
                                "(torch.nn.Parameter or None expected)"
                                .format(torch.typename(value), name))
            self.register_parameter(name, value)
        else:
            modules = self.__dict__.get('_modules')
            if isinstance(value, Module):
                if modules is None:
                    raise AttributeError(
                        "cannot assign module before Module.__init__() call")
                remove_from(self.__dict__, self._parameters, self._buffers)
                modules[name] = value
            elif modules is not None and name in modules:
                if value is not None:
                    raise TypeError("cannot assign '{}' as child module '{}' "
                                    "(torch.nn.Module or None expected)"
                                    .format(torch.typename(value), name))
                modules[name] = value
            else:
                buffers = self.__dict__.get('_buffers')
                if buffers is not None and name in buffers:
                    if value is not None and not isinstance(value, torch.Tensor):
                        raise TypeError("cannot assign '{}' as buffer '{}' "
                                        "(torch.Tensor or None expected)"
                                        .format(torch.typename(value), name))
                    buffers[name] = value
                else:
                    object.__setattr__(self, name, value)

    def __delattr__(self, name):
        if name in self._parameters:
            del self._parameters[name]
        elif name in self._buffers:
            del self._buffers[name]
        elif name in self._modules:
            del self._modules[name]
        else:
            object.__delattr__(self, name)

    def _register_state_dict_hook(self, hook):
        r"""These hooks will be called with arguments: `self`, `state_dict`,
        `prefix`, `local_metadata`, after the `state_dict` of `self` is set.
        Note that only parameters and buffers of `self` or its children are
        guaranteed to exist in `state_dict`. The hooks may modify `state_dict`
        inplace or return a new one.
        """
        handle = hooks.RemovableHandle(self._state_dict_hooks)
        self._state_dict_hooks[handle.id] = hook
        return handle

    def state_dict(self, destination=None, prefix='', keep_vars=False):
        r"""Returns a dictionary containing a whole state of the module.

        Both parameters and persistent buffers (e.g. running averages) are
        included. Keys are corresponding parameter and buffer names.

        Returns:
            dict:
                a dictionary containing a whole state of the module

        Example::

            >>> module.state_dict().keys()
            ['bias', 'weight']

        """
        if destination is None:
            destination = OrderedDict()
            destination._metadata = OrderedDict()
        destination._metadata[prefix[:-1]] = local_metadata = dict(version=self._version)
        for name, param in self._parameters.items():
            if param is not None:
                destination[prefix + name] = param if keep_vars else param.data
        for name, buf in self._buffers.items():
            if buf is not None:
                destination[prefix + name] = buf if keep_vars else buf.data
        for name, module in self._modules.items():
            if module is not None:
                module.state_dict(destination, prefix + name + '.', keep_vars=keep_vars)
        for hook in self._state_dict_hooks.values():
            hook_result = hook(self, destination, prefix, local_metadata)
            if hook_result is not None:
                destination = hook_result
        return destination
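
    # Usage sketch (illustrative): ``state_dict`` is the standard checkpointing
    # path. The ``model`` variable and file name below are hypothetical.
    #
    #     torch.save(model.state_dict(), 'checkpoint.pth')
    #     model.load_state_dict(torch.load('checkpoint.pth'))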

    def _register_load_state_dict_pre_hook(self, hook):
        r"""These hooks will be called with arguments: `state_dict`, `prefix`,
        `local_metadata`, `strict`, `missing_keys`, `unexpected_keys`,
        `error_msgs`, before loading `state_dict` into `self`. These arguments
        are exactly the same as those of `_load_from_state_dict`.
        """
        handle = hooks.RemovableHandle(self._load_state_dict_pre_hooks)
        self._load_state_dict_pre_hooks[handle.id] = hook
        return handle

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        r"""Copies parameters and buffers from :attr:`state_dict` into only
        this module, but not its descendants. This is called on every submodule
        in :meth:`~torch.nn.Module.load_state_dict`. Metadata saved for this
        module in input :attr:`state_dict` is provided as :attr:`local_metadata`.
        For state dicts without metadata, :attr:`local_metadata` is empty.
        Subclasses can achieve class-specific backward compatible loading using
        the version number at `local_metadata.get("version", None)`.

        .. note::
            :attr:`state_dict` is not the same object as the input
            :attr:`state_dict` to :meth:`~torch.nn.Module.load_state_dict`. So
            it can be modified.

        Arguments:
            state_dict (dict): a dict containing parameters and
                persistent buffers.
            prefix (str): the prefix for parameters and buffers used in this
                module
            local_metadata (dict): a dict containing the metadata for this module.
            strict (bool): whether to strictly enforce that the keys in
                :attr:`state_dict` with :attr:`prefix` match the names of
                parameters and buffers in this module
            missing_keys (list of str): if ``strict=True``, add missing keys to
                this list
            unexpected_keys (list of str): if ``strict=True``, add unexpected
                keys to this list
            error_msgs (list of str): error messages should be added to this
                list, and will be reported together in
                :meth:`~torch.nn.Module.load_state_dict`
        """
        for hook in self._load_state_dict_pre_hooks.values():
            hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)

        local_name_params = itertools.chain(self._parameters.items(), self._buffers.items())
        local_state = {k: v.data for k, v in local_name_params if v is not None}

        for name, param in local_state.items():
            key = prefix + name
            if key in state_dict:
                input_param = state_dict[key]

                # Backward compatibility: loading 1-dim tensor from 0.3.* to version 0.4+
                if len(param.shape) == 0 and len(input_param.shape) == 1:
                    input_param = input_param[0]

                if input_param.shape != param.shape:
                    # local shape should match the one in checkpoint
                    error_msgs.append('size mismatch for {}: copying a param with shape {} from checkpoint, '
                                      'the shape in current model is {}.'
                                      .format(key, input_param.shape, param.shape))
                    continue

                if isinstance(input_param, Parameter):
                    # backwards compatibility for serialized parameters
                    input_param = input_param.data
                try:
                    param.copy_(input_param)
                except Exception:
                    error_msgs.append('An exception occurred while copying the parameter named "{}", '
                                      'whose dimensions in the model are {} and '
                                      'whose dimensions in the checkpoint are {}.'
                                      .format(key, param.size(), input_param.size()))
            elif strict:
                missing_keys.append(key)

        if strict:
            for key in state_dict.keys():
                if key.startswith(prefix):
                    input_name = key[len(prefix):]
                    input_name = input_name.split('.', 1)[0]  # get the name of param/buffer/child
                    if input_name not in self._modules and input_name not in local_state:
                        unexpected_keys.append(key)

    def load_state_dict(self, state_dict, strict=True):
        r"""Copies parameters and buffers from :attr:`state_dict` into
        this module and its descendants. If :attr:`strict` is ``True``, then
        the keys of :attr:`state_dict` must exactly match the keys returned
        by this module's :meth:`~torch.nn.Module.state_dict` function.

        Arguments:
            state_dict (dict): a dict containing parameters and
                persistent buffers.
            strict (bool, optional): whether to strictly enforce that the keys
                in :attr:`state_dict` match the keys returned by this module's
                :meth:`~torch.nn.Module.state_dict` function. Default: ``True``

        Returns:
            ``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields:
                * **missing_keys** is a list of str containing the missing keys
                * **unexpected_keys** is a list of str containing the unexpected keys
        """
        missing_keys = []
        unexpected_keys = []
        error_msgs = []

        # copy state_dict so _load_from_state_dict can modify it
        metadata = getattr(state_dict, '_metadata', None)
        state_dict = state_dict.copy()
        if metadata is not None:
            state_dict._metadata = metadata

        def load(module, prefix=''):
            local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
            module._load_from_state_dict(
                state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
            for name, child in module._modules.items():
                if child is not None:
                    load(child, prefix + name + '.')

        load(self)
        load = None  # break load->load reference cycle

        if strict:
            if len(unexpected_keys) > 0:
                error_msgs.insert(
                    0, 'Unexpected key(s) in state_dict: {}. '.format(
                        ', '.join('"{}"'.format(k) for k in unexpected_keys)))
            if len(missing_keys) > 0:
                error_msgs.insert(
                    0, 'Missing key(s) in state_dict: {}. '.format(
                        ', '.join('"{}"'.format(k) for k in missing_keys)))

        if len(error_msgs) > 0:
            raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
                               self.__class__.__name__, "\n\t".join(error_msgs)))
        return _IncompatibleKeys(missing_keys, unexpected_keys)
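
    # Usage sketch (illustrative): with ``strict=False``, key mismatches are
    # reported through the returned named tuple instead of raising, which is
    # convenient when loading a partially matching checkpoint. The ``model`` and
    # ``checkpoint`` variables are hypothetical.
    #
    #     result = model.load_state_dict(checkpoint, strict=False)
    #     print(result.missing_keys)
    #     print(result.unexpected_keys)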

    def _named_members(self, get_members_fn, prefix='', recurse=True):
        r"""Helper method for yielding various names + members of modules."""
        memo = set()
        modules = self.named_modules(prefix=prefix) if recurse else [(prefix, self)]
        for module_prefix, module in modules:
            members = get_members_fn(module)
            for k, v in members:
                if v is None or v in memo:
                    continue
                memo.add(v)
                name = module_prefix + ('.' if module_prefix else '') + k
                yield name, v

    def parameters(self, recurse=True):
        r"""Returns an iterator over module parameters.

        This is typically passed to an optimizer.

        Args:
            recurse (bool): if True, then yields parameters of this module
                and all submodules. Otherwise, yields only parameters that
                are direct members of this module.

        Yields:
            Parameter: module parameter

        Example::

            >>> for param in model.parameters():
            >>>     print(type(param.data), param.size())
            <class 'torch.FloatTensor'> (20L,)
            <class 'torch.FloatTensor'> (20L, 1L, 5L, 5L)

        """
        for name, param in self.named_parameters(recurse=recurse):
            yield param

    def named_parameters(self, prefix='', recurse=True):
        r"""Returns an iterator over module parameters, yielding both the
        name of the parameter as well as the parameter itself.

        Args:
            prefix (str): prefix to prepend to all parameter names.
            recurse (bool): if True, then yields parameters of this module
                and all submodules. Otherwise, yields only parameters that
                are direct members of this module.

        Yields:
            (string, Parameter): Tuple containing the name and parameter

        Example::

            >>> for name, param in self.named_parameters():
            >>>     if name in ['bias']:
            >>>         print(param.size())

        """
        gen = self._named_members(
            lambda module: module._parameters.items(),
            prefix=prefix, recurse=recurse)
        for elem in gen:
            yield elem

    def buffers(self, recurse=True):
        r"""Returns an iterator over module buffers.

        Args:
            recurse (bool): if True, then yields buffers of this module
                and all submodules. Otherwise, yields only buffers that
                are direct members of this module.

        Yields:
            torch.Tensor: module buffer

        Example::

            >>> for buf in model.buffers():
            >>>     print(type(buf.data), buf.size())
            <class 'torch.FloatTensor'> (20L,)
            <class 'torch.FloatTensor'> (20L, 1L, 5L, 5L)

        """
        for name, buf in self.named_buffers(recurse=recurse):
            yield buf

    def named_buffers(self, prefix='', recurse=True):
        r"""Returns an iterator over module buffers, yielding both the
        name of the buffer as well as the buffer itself.

        Args:
            prefix (str): prefix to prepend to all buffer names.
            recurse (bool): if True, then yields buffers of this module
                and all submodules. Otherwise, yields only buffers that
                are direct members of this module.

        Yields:
            (string, torch.Tensor): Tuple containing the name and buffer

        Example::

            >>> for name, buf in self.named_buffers():
            >>>     if name in ['running_var']:
            >>>         print(buf.size())

        """
        gen = self._named_members(
            lambda module: module._buffers.items(),
            prefix=prefix, recurse=recurse)
        for elem in gen:
            yield elem

    def children(self):
        r"""Returns an iterator over immediate children modules.

        Yields:
            Module: a child module
        """
        for name, module in self.named_children():
            yield module

    def named_children(self):
        r"""Returns an iterator over immediate children modules, yielding both
        the name of the module as well as the module itself.

        Yields:
            (string, Module): Tuple containing a name and child module

        Example::

            >>> for name, module in model.named_children():
            >>>     if name in ['conv4', 'conv5']:
            >>>         print(module)

        """
        memo = set()
        for name, module in self._modules.items():
            if module is not None and module not in memo:
                memo.add(module)
                yield name, module

    def modules(self):
        r"""Returns an iterator over all modules in the network.

        Yields:
            Module: a module in the network

        Note:
            Duplicate modules are returned only once. In the following
            example, ``l`` will be returned only once.

        Example::

            >>> l = nn.Linear(2, 2)
            >>> net = nn.Sequential(l, l)
            >>> for idx, m in enumerate(net.modules()):
                    print(idx, '->', m)

            0 -> Sequential(
              (0): Linear(in_features=2, out_features=2, bias=True)
              (1): Linear(in_features=2, out_features=2, bias=True)
            )
            1 -> Linear(in_features=2, out_features=2, bias=True)

        """
        for name, module in self.named_modules():
            yield module

    def named_modules(self, memo=None, prefix=''):
        r"""Returns an iterator over all modules in the network, yielding
        both the name of the module as well as the module itself.

        Yields:
            (string, Module): Tuple of name and module

        Note:
            Duplicate modules are returned only once. In the following
            example, ``l`` will be returned only once.

        Example::

            >>> l = nn.Linear(2, 2)
            >>> net = nn.Sequential(l, l)
            >>> for idx, m in enumerate(net.named_modules()):
                    print(idx, '->', m)

            0 -> ('', Sequential(
              (0): Linear(in_features=2, out_features=2, bias=True)
              (1): Linear(in_features=2, out_features=2, bias=True)
            ))
            1 -> ('0', Linear(in_features=2, out_features=2, bias=True))

        """
        if memo is None:
            memo = set()
        if self not in memo:
            memo.add(self)
            yield prefix, self
            for name, module in self._modules.items():
                if module is None:
                    continue
                submodule_prefix = prefix + ('.' if prefix else '') + name
                for m in module.named_modules(memo, submodule_prefix):
                    yield m

    def train(self, mode=True):
        r"""Sets the module in training mode.

        This has an effect only on certain modules. See the documentation of
        particular modules for details of their behavior in training/evaluation
        mode, if they are affected, e.g. :class:`Dropout`, :class:`BatchNorm`,
        etc.

        Returns:
            Module: self
        """
        self.training = mode
        for module in self.children():
            module.train(mode)
        return self

    def eval(self):
        r"""Sets the module in evaluation mode.

        This has an effect only on certain modules. See the documentation of
        particular modules for details of their behavior in training/evaluation
        mode, if they are affected, e.g. :class:`Dropout`, :class:`BatchNorm`,
        etc.
        """
        return self.train(False)
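
    # Usage sketch (illustrative): toggling ``train()`` / ``eval()`` changes the
    # behavior of modules such as Dropout and BatchNorm; evaluation is usually
    # combined with ``torch.no_grad()``. The ``model`` and ``example_input``
    # variables are hypothetical.
    #
    #     model.eval()
    #     with torch.no_grad():
    #         prediction = model(example_input)
    #     model.train()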

    def zero_grad(self):
        r"""Sets gradients of all model parameters to zero."""
        for p in self.parameters():
            if p.grad is not None:
                p.grad.detach_()
                p.grad.zero_()
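
    # Usage sketch (illustrative): ``zero_grad`` is called once per iteration so
    # gradients from the previous step do not accumulate. The ``loader``,
    # ``criterion``, ``model``, and ``optimizer`` names are hypothetical.
    #
    #     for data, target in loader:
    #         model.zero_grad()
    #         loss = criterion(model(data), target)
    #         loss.backward()
    #         optimizer.step()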

    def share_memory(self):
        return self._apply(lambda t: t.share_memory_())

    def _get_name(self):
        return self.__class__.__name__

    def extra_repr(self):
        r"""Set the extra representation of the module.

        To print customized extra information, you should reimplement
        this method in your own modules. Both single-line and multi-line
        strings are acceptable.
        """
        return ''
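
    # Usage sketch (illustrative, assuming ``torch.nn as nn`` is imported):
    # overriding ``extra_repr`` controls the single-line summary printed by
    # ``__repr__`` below. The module name ``Scale`` is hypothetical.
    #
    #     class Scale(nn.Module):
    #         def __init__(self, factor):
    #             super(Scale, self).__init__()
    #             self.factor = factor
    #
    #         def extra_repr(self):
    #             return 'factor={}'.format(self.factor)
    #
    #     print(Scale(2.0))   # prints: Scale(factor=2.0)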

    def __repr__(self):
        # We treat the extra repr like the sub-module, one item per line
        extra_lines = []
        extra_repr = self.extra_repr()
        # empty string will be split into list ['']
        if extra_repr:
            extra_lines = extra_repr.split('\n')
        child_lines = []
        for key, module in self._modules.items():
            mod_str = repr(module)
            mod_str = _addindent(mod_str, 2)
            child_lines.append('(' + key + '): ' + mod_str)
        lines = extra_lines + child_lines

        main_str = self._get_name() + '('
        if lines:
            # simple one-liner info, which most builtin Modules will use
            if len(extra_lines) == 1 and not child_lines:
                main_str += extra_lines[0]
            else:
                main_str += '\n  ' + '\n  '.join(lines) + '\n'

        main_str += ')'
        return main_str

    def __dir__(self):
        module_attrs = dir(self.__class__)
        attrs = list(self.__dict__.keys())
        parameters = list(self._parameters.keys())
        modules = list(self._modules.keys())
        buffers = list(self._buffers.keys())
        keys = module_attrs + attrs + parameters + modules + buffers

        # Eliminate attrs that are not legal Python variable names
        keys = [key for key in keys if not key[0].isdigit()]

        return sorted(keys)