from collections import OrderedDict, namedtuple
import itertools
import warnings
import functools

import torch
from ..parameter import Parameter
import torch.utils.hooks as hooks

from torch import Tensor, device, dtype
from typing import Union, Tuple, Any, Callable, Iterator, Set, Optional, overload, TypeVar, Mapping, Dict, List
from ...utils.hooks import RemovableHandle

_grad_t = Union[Tuple[Tensor, ...], Tensor]
# See https://mypy.readthedocs.io/en/latest/generics.html#generic-methods-and-generic-self for the use
# of `T` to annotate `self`. Many methods of `Module` return `self` and we want those return values to be
# the type of the subclass, not the looser type of `Module`.
T = TypeVar('T', bound='Module')

class _IncompatibleKeys(namedtuple('IncompatibleKeys', ['missing_keys', 'unexpected_keys'])):
    def __repr__(self):
        if not self.missing_keys and not self.unexpected_keys:
            return '<All keys matched successfully>'
        return super(_IncompatibleKeys, self).__repr__()

    __str__ = __repr__


def _addindent(s_, numSpaces):
    s = s_.split('\n')
    # don't do anything for single-line stuff
    if len(s) == 1:
        return s_
    first = s.pop(0)
    s = [(numSpaces * ' ') + line for line in s]
    s = '\n'.join(s)
    s = first + '\n' + s
    return s


r"""This tracks hooks common to all modules that are executed before/after
calling forward and backward. This is global state used for debugging/profiling
purposes."""
_global_backward_hooks: Dict[int, Callable] = OrderedDict()
_global_is_full_backward_hook: Optional[bool] = None
_global_forward_pre_hooks: Dict[int, Callable] = OrderedDict()
_global_forward_hooks: Dict[int, Callable] = OrderedDict()

_EXTRA_STATE_KEY_SUFFIX = '_extra_state'


def register_module_forward_pre_hook(hook: Callable[..., None]) -> RemovableHandle:
    r"""Registers a forward pre-hook common to all modules.

    .. warning ::

        This adds global state to the `nn.module` module
        and it is only intended for debugging/profiling purposes.

    The hook will be called every time before :func:`forward` is invoked.
    It should have the following signature::

        hook(module, input) -> None or modified input

    The input contains only the positional arguments given to the module.
    Keyword arguments won't be passed to the hooks and are only passed to the ``forward``.
    The hook can modify the input. The user can either return a tuple or a
    single modified value in the hook. We will wrap the value into a tuple
    if a single value is returned (unless that value is already a tuple).

    This hook has precedence over the specific module hooks registered with
    ``register_forward_pre_hook``.

    Returns:
        :class:`torch.utils.hooks.RemovableHandle`:
            a handle that can be used to remove the added hook by calling
            ``handle.remove()``
    """
    handle = hooks.RemovableHandle(_global_forward_pre_hooks)
    _global_forward_pre_hooks[handle.id] = hook
    return handle
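
# A minimal usage sketch (illustrative only, not part of this module): a global
# forward pre-hook that logs input shapes for every module in a model. The
# helper name `log_input_shapes` is hypothetical.
#
#     def log_input_shapes(module, inputs):
#         shapes = [tuple(i.shape) for i in inputs if isinstance(i, torch.Tensor)]
#         print(f"{type(module).__name__}: {shapes}")
#
#     handle = register_module_forward_pre_hook(log_input_shapes)
#     ...  # run some forward passes
#     handle.remove()  # global hooks should be removed when profiling is done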


def register_module_forward_hook(hook: Callable[..., None]) -> RemovableHandle:
    r"""Registers a global forward hook for all the modules.

    .. warning ::

        This adds global state to the `nn.module` module
        and it is only intended for debugging/profiling purposes.

    The hook will be called every time after :func:`forward` has computed an output.
    It should have the following signature::

        hook(module, input, output) -> None or modified output

    The input contains only the positional arguments given to the module.
    Keyword arguments won't be passed to the hooks and are only passed to the ``forward``.
    The hook can modify the output. It can modify the input inplace but
    it will not have an effect on forward since this is called after
    :func:`forward` is called.

    Returns:
        :class:`torch.utils.hooks.RemovableHandle`:
            a handle that can be used to remove the added hook by calling
            ``handle.remove()``

    This hook will be executed before specific module hooks registered with
    ``register_forward_hook``.
    """
    handle = hooks.RemovableHandle(_global_forward_hooks)
    _global_forward_hooks[handle.id] = hook
    return handle

def register_module_backward_hook(
    hook: Callable[['Module', _grad_t, _grad_t], Union[None, Tensor]]
) -> RemovableHandle:
    r"""Registers a backward hook common to all the modules.

    This function is deprecated in favor of
    :func:`torch.nn.modules.module.register_module_full_backward_hook`
    and the behavior of this function will change in future versions.

    Returns:
        :class:`torch.utils.hooks.RemovableHandle`:
            a handle that can be used to remove the added hook by calling
            ``handle.remove()``

    """
    global _global_is_full_backward_hook
    if _global_is_full_backward_hook is True:
        raise RuntimeError("Cannot use both regular backward hooks and full backward hooks as a "
                           "global Module hook. Please use only one of them.")

    _global_is_full_backward_hook = False

    handle = hooks.RemovableHandle(_global_backward_hooks)
    _global_backward_hooks[handle.id] = hook
    return handle

def register_module_full_backward_hook(
    hook: Callable[['Module', _grad_t, _grad_t], Union[None, Tensor]]
) -> RemovableHandle:
    r"""Registers a backward hook common to all the modules.

    .. warning ::
        This adds global state to the `nn.module` module
        and it is only intended for debugging/profiling purposes.

    The hook will be called every time the gradients with respect to module
    inputs are computed. The hook should have the following signature::

        hook(module, grad_input, grad_output) -> Tensor or None

    The :attr:`grad_input` and :attr:`grad_output` are tuples. The hook should
    not modify its arguments, but it can optionally return a new gradient with
    respect to the input that will be used in place of :attr:`grad_input` in
    subsequent computations. :attr:`grad_input` will only correspond to the inputs given
    as positional arguments and all kwarg arguments will not appear in the hook. Entries
    in :attr:`grad_input` and :attr:`grad_output` will be ``None`` for all non-Tensor
    arguments.

    For technical reasons, when this hook is applied to a Module, its forward function will
    receive a view of each Tensor passed to the Module. Similarly the caller will receive a view
    of each Tensor returned by the Module's forward function.

    Global hooks are called before hooks registered with `register_backward_hook`.

    Returns:
        :class:`torch.utils.hooks.RemovableHandle`:
            a handle that can be used to remove the added hook by calling
            ``handle.remove()``

    """
    global _global_is_full_backward_hook
    if _global_is_full_backward_hook is False:
        raise RuntimeError("Cannot use both regular backward hooks and full backward hooks as a "
                           "global Module hook. Please use only one of them.")

    _global_is_full_backward_hook = True

    handle = hooks.RemovableHandle(_global_backward_hooks)
    _global_backward_hooks[handle.id] = hook
    return handle
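
# A minimal usage sketch (illustrative only): a global full backward hook that
# reports the gradient norm flowing out of every module. `report_grad_norm`,
# `model`, and `x` are hypothetical user names.
#
#     def report_grad_norm(module, grad_input, grad_output):
#         total = sum(g.norm() for g in grad_output if g is not None)
#         print(f"{type(module).__name__}: grad_output norm = {total}")
#
#     handle = register_module_full_backward_hook(report_grad_norm)
#     loss = model(x).sum()
#     loss.backward()         # the hook fires as gradients are computed
#     handle.remove()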


# Trick mypy into not applying contravariance rules to inputs by defining
# forward as a value, rather than a function.  See also
# https://github.com/python/mypy/issues/8795
def _forward_unimplemented(self, *input: Any) -> None:
    r"""Defines the computation performed at every call.

    Should be overridden by all subclasses.

    .. note::
        Although the recipe for forward pass needs to be defined within
        this function, one should call the :class:`Module` instance afterwards
        instead of this since the former takes care of running the
        registered hooks while the latter silently ignores them.
    """
    raise NotImplementedError


class Module:
    r"""Base class for all neural network modules.

    Your models should also subclass this class.

    Modules can also contain other Modules, allowing them to be nested in
    a tree structure. You can assign the submodules as regular attributes::

        import torch.nn as nn
        import torch.nn.functional as F

        class Model(nn.Module):
            def __init__(self):
                super().__init__()
                self.conv1 = nn.Conv2d(1, 20, 5)
                self.conv2 = nn.Conv2d(20, 20, 5)

            def forward(self, x):
                x = F.relu(self.conv1(x))
                return F.relu(self.conv2(x))

    Submodules assigned in this way will be registered, and will have their
    parameters converted too when you call :meth:`to`, etc.

    .. note::
        As per the example above, an ``__init__()`` call to the parent class
        must be made before assignment on the child.

    :ivar training: Boolean representing whether this module is in training or
                    evaluation mode.
    :vartype training: bool
    """

    dump_patches: bool = False

    r"""This allows better BC support for :meth:`load_state_dict`. In
    :meth:`state_dict`, the version number will be saved as in the attribute
    `_metadata` of the returned state dict, and thus pickled. `_metadata` is a
    dictionary with keys that follow the naming convention of state dict. See
    ``_load_from_state_dict`` on how to use this information in loading.

    If new parameters/buffers are added/removed from a module, this number shall
    be bumped, and the module's `_load_from_state_dict` method can compare the
    version number and do appropriate changes if the state dict is from before
    the change."""
    _version: int = 1

    training: bool
    _is_full_backward_hook: Optional[bool]

    def __init__(self) -> None:
        """
        Initializes internal Module state, shared by both nn.Module and ScriptModule.
        """
        torch._C._log_api_usage_once("python.nn_module")

        self.training = True
        self._parameters: Dict[str, Optional[Parameter]] = OrderedDict()
        self._buffers: Dict[str, Optional[Tensor]] = OrderedDict()
        self._non_persistent_buffers_set: Set[str] = set()
        self._backward_hooks: Dict[int, Callable] = OrderedDict()
        self._is_full_backward_hook = None
        self._forward_hooks: Dict[int, Callable] = OrderedDict()
        self._forward_pre_hooks: Dict[int, Callable] = OrderedDict()
        self._state_dict_hooks: Dict[int, Callable] = OrderedDict()
        self._load_state_dict_pre_hooks: Dict[int, Callable] = OrderedDict()
        self._modules: Dict[str, Optional['Module']] = OrderedDict()

    forward: Callable[..., Any] = _forward_unimplemented

    def register_buffer(self, name: str, tensor: Optional[Tensor], persistent: bool = True) -> None:
        r"""Adds a buffer to the module.

        This is typically used to register a buffer that should not be
        considered a model parameter. For example, BatchNorm's ``running_mean``
        is not a parameter, but is part of the module's state. Buffers, by
        default, are persistent and will be saved alongside parameters. This
        behavior can be changed by setting :attr:`persistent` to ``False``. The
        only difference between a persistent buffer and a non-persistent buffer
        is that the latter will not be a part of this module's
        :attr:`state_dict`.

        Buffers can be accessed as attributes using given names.

        Args:
            name (string): name of the buffer. The buffer can be accessed
                from this module using the given name
            tensor (Tensor or None): buffer to be registered. If ``None``, then operations
                that run on buffers, such as :attr:`cuda`, are ignored. If ``None``,
                the buffer is **not** included in the module's :attr:`state_dict`.
            persistent (bool): whether the buffer is part of this module's
                :attr:`state_dict`.

        Example::

            >>> self.register_buffer('running_mean', torch.zeros(num_features))

        """
        if persistent is False and isinstance(self, torch.jit.ScriptModule):
            raise RuntimeError("ScriptModule does not support non-persistent buffers")

        if '_buffers' not in self.__dict__:
            raise AttributeError(
                "cannot assign buffer before Module.__init__() call")
        elif not isinstance(name, torch._six.string_classes):
            raise TypeError("buffer name should be a string. "
                            "Got {}".format(torch.typename(name)))
        elif '.' in name:
            raise KeyError("buffer name can't contain \".\"")
        elif name == '':
            raise KeyError("buffer name can't be empty string \"\"")
        elif hasattr(self, name) and name not in self._buffers:
            raise KeyError("attribute '{}' already exists".format(name))
        elif tensor is not None and not isinstance(tensor, torch.Tensor):
            raise TypeError("cannot assign '{}' object to buffer '{}' "
                            "(torch Tensor or None required)"
                            .format(torch.typename(tensor), name))
        else:
            self._buffers[name] = tensor
            if persistent:
                self._non_persistent_buffers_set.discard(name)
            else:
                self._non_persistent_buffers_set.add(name)
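
    # A minimal usage sketch (illustrative only): registering a non-persistent
    # buffer from a subclass's __init__ and checking that it is excluded from
    # state_dict() while still being reachable as an attribute.
    #
    #     class Counter(nn.Module):
    #         def __init__(self):
    #             super().__init__()
    #             self.register_buffer('steps', torch.zeros(1), persistent=False)
    #
    #     m = Counter()
    #     assert 'steps' not in m.state_dict()
    #     assert hasattr(m, 'steps')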

    def register_parameter(self, name: str, param: Optional[Parameter]) -> None:
        r"""Adds a parameter to the module.

        The parameter can be accessed as an attribute using the given name.

        Args:
            name (string): name of the parameter. The parameter can be accessed
                from this module using the given name
            param (Parameter or None): parameter to be added to the module. If
                ``None``, then operations that run on parameters, such as :attr:`cuda`,
                are ignored. If ``None``, the parameter is **not** included in the
                module's :attr:`state_dict`.
        """
        if '_parameters' not in self.__dict__:
            raise AttributeError(
                "cannot assign parameter before Module.__init__() call")

        elif not isinstance(name, torch._six.string_classes):
            raise TypeError("parameter name should be a string. "
                            "Got {}".format(torch.typename(name)))
        elif '.' in name:
            raise KeyError("parameter name can't contain \".\"")
        elif name == '':
            raise KeyError("parameter name can't be empty string \"\"")
        elif hasattr(self, name) and name not in self._parameters:
            raise KeyError("attribute '{}' already exists".format(name))

        if param is None:
            self._parameters[name] = None
        elif not isinstance(param, Parameter):
            raise TypeError("cannot assign '{}' object to parameter '{}' "
                            "(torch.nn.Parameter or None required)"
                            .format(torch.typename(param), name))
        elif param.grad_fn:
            raise ValueError(
                "Cannot assign non-leaf Tensor to parameter '{0}'. Model "
                "parameters must be created explicitly. To express '{0}' "
                "as a function of another Tensor, compute the value in "
                "the forward() method.".format(name))
        else:
            self._parameters[name] = param
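
    # A minimal usage sketch (illustrative only): registering a parameter
    # explicitly rather than by attribute assignment; both forms end up in
    # self._parameters and are returned by .parameters().
    #
    #     class Scale(nn.Module):
    #         def __init__(self):
    #             super().__init__()
    #             self.register_parameter('gain', nn.Parameter(torch.ones(1)))
    #
    #         def forward(self, x):
    #             return self.gain * x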

    def add_module(self, name: str, module: Optional['Module']) -> None:
        r"""Adds a child module to the current module.

        The module can be accessed as an attribute using the given name.

        Args:
            name (string): name of the child module. The child module can be
                accessed from this module using the given name
            module (Module): child module to be added to the module.
        """
        if not isinstance(module, Module) and module is not None:
            raise TypeError("{} is not a Module subclass".format(
                torch.typename(module)))
        elif not isinstance(name, torch._six.string_classes):
            raise TypeError("module name should be a string. Got {}".format(
                torch.typename(name)))
        elif hasattr(self, name) and name not in self._modules:
            raise KeyError("attribute '{}' already exists".format(name))
        elif '.' in name:
            raise KeyError("module name can't contain \".\", got: {}".format(name))
        elif name == '':
            raise KeyError("module name can't be empty string \"\"")
        self._modules[name] = module

    def register_module(self, name: str, module: Optional['Module']) -> None:
        r"""Alias for :func:`add_module`."""
        self.add_module(name, module)

    def get_submodule(self, target: str) -> "Module":
        """
        Returns the submodule given by ``target`` if it exists,
        otherwise throws an error.

        For example, let's say you have an ``nn.Module`` ``A`` that
        looks like this:

        .. code-block:: text

            A(
                (net_b): Module(
                    (net_c): Module(
                        (conv): Conv2d(16, 33, kernel_size=(3, 3), stride=(2, 2))
                    )
                    (linear): Linear(in_features=100, out_features=200, bias=True)
                )
            )

        (The diagram shows an ``nn.Module`` ``A``. ``A`` has a nested
        submodule ``net_b``, which itself has two submodules ``net_c``
        and ``linear``. ``net_c`` then has a submodule ``conv``.)

        To check whether or not we have the ``linear`` submodule, we
        would call ``get_submodule("net_b.linear")``. To check whether
        we have the ``conv`` submodule, we would call
        ``get_submodule("net_b.net_c.conv")``.

        The runtime of ``get_submodule`` is bounded by the degree
        of module nesting in ``target``. A query against
        ``named_modules`` achieves the same result, but it is O(N) in
        the number of transitive modules. So, for a simple check to see
        if some submodule exists, ``get_submodule`` should always be
        used.

        Args:
            target: The fully-qualified string name of the submodule
                to look for. (See above example for how to specify a
                fully-qualified string.)

        Returns:
            torch.nn.Module: The submodule referenced by ``target``

        Raises:
            AttributeError: If the target string references an invalid
                path or resolves to something that is not an
                ``nn.Module``
        """
        if target == "":
            return self

        atoms: List[str] = target.split(".")
        mod: torch.nn.Module = self

        for item in atoms:

            if not hasattr(mod, item):
                raise AttributeError(mod._get_name() + " has no "
                                     "attribute `" + item + "`")

            mod = getattr(mod, item)

            if not isinstance(mod, torch.nn.Module):
                raise AttributeError("`" + item + "` is not "
                                     "an nn.Module")

        return mod
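
    # A minimal usage sketch (illustrative only), reusing the ``A`` diagram from
    # the docstring above: fetching a deeply nested submodule by its
    # fully-qualified name instead of scanning named_modules(). Here `a` stands
    # in for an instance of ``A``.
    #
    #     conv = a.get_submodule("net_b.net_c.conv")
    #     a.get_submodule("net_b.missing")   # raises AttributeError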

    def get_parameter(self, target: str) -> "Parameter":
        """
        Returns the parameter given by ``target`` if it exists,
        otherwise throws an error.

        See the docstring for ``get_submodule`` for a more detailed
        explanation of this method's functionality as well as how to
        correctly specify ``target``.

        Args:
            target: The fully-qualified string name of the Parameter
                to look for. (See ``get_submodule`` for how to specify a
                fully-qualified string.)

        Returns:
            torch.nn.Parameter: The Parameter referenced by ``target``

        Raises:
            AttributeError: If the target string references an invalid
                path or resolves to something that is not an
                ``nn.Parameter``
        """
        module_path, _, param_name = target.rpartition(".")

        mod: torch.nn.Module = self.get_submodule(module_path)

        if not hasattr(mod, param_name):
            raise AttributeError(mod._get_name() + " has no attribute `"
                                 + param_name + "`")

        param: torch.nn.Parameter = getattr(mod, param_name)

        if not isinstance(param, torch.nn.Parameter):
            raise AttributeError("`" + param_name + "` is not an "
                                 "nn.Parameter")

        return param

    def get_buffer(self, target: str) -> "Tensor":
        """
        Returns the buffer given by ``target`` if it exists,
        otherwise throws an error.

        See the docstring for ``get_submodule`` for a more detailed
        explanation of this method's functionality as well as how to
        correctly specify ``target``.

        Args:
            target: The fully-qualified string name of the buffer
                to look for. (See ``get_submodule`` for how to specify a
                fully-qualified string.)

        Returns:
            torch.Tensor: The buffer referenced by ``target``

        Raises:
            AttributeError: If the target string references an invalid
                path or resolves to something that is not a
                buffer
        """
        module_path, _, buffer_name = target.rpartition(".")

        mod: torch.nn.Module = self.get_submodule(module_path)

        if not hasattr(mod, buffer_name):
            raise AttributeError(mod._get_name() + " has no attribute `"
                                 + buffer_name + "`")

        buffer: torch.Tensor = getattr(mod, buffer_name)

        if buffer_name not in mod._buffers:
            raise AttributeError("`" + buffer_name + "` is not a buffer")

        return buffer
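
    # A minimal usage sketch (illustrative only): get_parameter/get_buffer take
    # the same dotted paths that state_dict() keys use.
    #
    #     m = nn.Sequential(nn.Linear(2, 2), nn.BatchNorm1d(2))
    #     w = m.get_parameter("0.weight")        # same tensor as m[0].weight
    #     rm = m.get_buffer("1.running_mean")    # same tensor as m[1].running_mean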

    def get_extra_state(self) -> Any:
        """
        Returns any extra state to include in the module's state_dict.
        Implement this and a corresponding :func:`set_extra_state` for your module
        if you need to store extra state. This function is called when building the
        module's `state_dict()`.

        Note that extra state should be pickleable to ensure working serialization
        of the state_dict. We only provide backwards compatibility guarantees
        for serializing Tensors; other objects may break backwards compatibility if
        their serialized pickled form changes.

        Returns:
            object: Any extra state to store in the module's state_dict
        """
        raise RuntimeError(
            "Reached a code path in Module.get_extra_state() that should never be called. "
            "Please file an issue at https://github.com/pytorch/pytorch/issues/new?template=bug-report.md "
            "to report this bug.")

    def set_extra_state(self, state: Any):
        """
        This function is called from :func:`load_state_dict` to handle any extra state
        found within the `state_dict`. Implement this function and a corresponding
        :func:`get_extra_state` for your module if you need to store extra state within its
        `state_dict`.

        Args:
            state (dict): Extra state from the `state_dict`
        """
        raise RuntimeError(
            "Reached a code path in Module.set_extra_state() that should never be called. "
            "Please file an issue at https://github.com/pytorch/pytorch/issues/new?template=bug-report.md "
            "to report this bug.")

    def _apply(self, fn):
        for module in self.children():
            module._apply(fn)

        def compute_should_use_set_data(tensor, tensor_applied):
            if torch._has_compatible_shallow_copy_type(tensor, tensor_applied):
                # If the new tensor has compatible tensor type as the existing tensor,
                # the current behavior is to change the tensor in-place using `.data =`,
                # and the future behavior is to overwrite the existing tensor. However,
                # changing the current behavior is a BC-breaking change, and we want it
                # to happen in future releases. So for now we introduce the
                # `torch.__future__.get_overwrite_module_params_on_conversion()`
                # global flag to let the user control whether they want the future
                # behavior of overwriting the existing tensor or not.
                return not torch.__future__.get_overwrite_module_params_on_conversion()
            else:
                return False

        for key, param in self._parameters.items():
            if param is None:
                continue
            # Tensors stored in modules are graph leaves, and we don't want to
            # track autograd history of `param_applied`, so we have to use
            # `with torch.no_grad():`
            with torch.no_grad():
                param_applied = fn(param)
            should_use_set_data = compute_should_use_set_data(param, param_applied)
            if should_use_set_data:
                param.data = param_applied
                out_param = param
            else:
                assert isinstance(param, Parameter)
                assert param.is_leaf
                out_param = Parameter(param_applied, param.requires_grad)
                self._parameters[key] = out_param

            if param.grad is not None:
                with torch.no_grad():
                    grad_applied = fn(param.grad)
                should_use_set_data = compute_should_use_set_data(param.grad, grad_applied)
                if should_use_set_data:
                    out_param.grad.data = grad_applied
                else:
                    assert param.grad.is_leaf
                    out_param.grad = grad_applied.requires_grad_(param.grad.requires_grad)

        for key, buf in self._buffers.items():
            if buf is not None:
                self._buffers[key] = fn(buf)

        return self

    def apply(self: T, fn: Callable[['Module'], None]) -> T:
        r"""Applies ``fn`` recursively to every submodule (as returned by ``.children()``)
        as well as self. Typical use includes initializing the parameters of a model
        (see also :ref:`nn-init-doc`).

        Args:
            fn (:class:`Module` -> None): function to be applied to each submodule

        Returns:
            Module: self

        Example::

            >>> @torch.no_grad()
            >>> def init_weights(m):
            >>>     print(m)
            >>>     if type(m) == nn.Linear:
            >>>         m.weight.fill_(1.0)
            >>>         print(m.weight)
            >>> net = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))
            >>> net.apply(init_weights)
            Linear(in_features=2, out_features=2, bias=True)
            Parameter containing:
            tensor([[ 1.,  1.],
                    [ 1.,  1.]])
            Linear(in_features=2, out_features=2, bias=True)
            Parameter containing:
            tensor([[ 1.,  1.],
                    [ 1.,  1.]])
            Sequential(
              (0): Linear(in_features=2, out_features=2, bias=True)
              (1): Linear(in_features=2, out_features=2, bias=True)
            )
            Sequential(
              (0): Linear(in_features=2, out_features=2, bias=True)
              (1): Linear(in_features=2, out_features=2, bias=True)
            )
        """
        for module in self.children():
            module.apply(fn)
        fn(self)
        return self
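
    # A minimal usage sketch (illustrative only): the same initialization idiom
    # as the docstring example above, written as a reusable function.
    #
    #     @torch.no_grad()
    #     def init_weights(m):
    #         if isinstance(m, nn.Linear):
    #             m.weight.fill_(1.0)
    #             m.bias.zero_()
    #
    #     net = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))
    #     net.apply(init_weights)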

    def cuda(self: T, device: Optional[Union[int, device]] = None) -> T:
        r"""Moves all model parameters and buffers to the GPU.

        This also makes associated parameters and buffers different objects. So
        it should be called before constructing the optimizer if the module will
        live on GPU while being optimized.

        .. note::
            This method modifies the module in-place.

        Args:
            device (int, optional): if specified, all parameters will be
                copied to that device

        Returns:
            Module: self
        """
        return self._apply(lambda t: t.cuda(device))

    def xpu(self: T, device: Optional[Union[int, device]] = None) -> T:
        r"""Moves all model parameters and buffers to the XPU.

        This also makes associated parameters and buffers different objects. So
        it should be called before constructing the optimizer if the module will
        live on XPU while being optimized.

        .. note::
            This method modifies the module in-place.

        Args:
            device (int, optional): if specified, all parameters will be
                copied to that device

        Returns:
            Module: self
        """
        return self._apply(lambda t: t.xpu(device))

    def cpu(self: T) -> T:
        r"""Moves all model parameters and buffers to the CPU.

        .. note::
            This method modifies the module in-place.

        Returns:
            Module: self
        """
        return self._apply(lambda t: t.cpu())

    def type(self: T, dst_type: Union[dtype, str]) -> T:
        r"""Casts all parameters and buffers to :attr:`dst_type`.

        .. note::
            This method modifies the module in-place.

        Args:
            dst_type (type or string): the desired type

        Returns:
            Module: self
        """
        return self._apply(lambda t: t.type(dst_type))

    def float(self: T) -> T:
        r"""Casts all floating point parameters and buffers to ``float`` datatype.

        .. note::
            This method modifies the module in-place.

        Returns:
            Module: self
        """
        return self._apply(lambda t: t.float() if t.is_floating_point() else t)

    def double(self: T) -> T:
        r"""Casts all floating point parameters and buffers to ``double`` datatype.

        .. note::
            This method modifies the module in-place.

        Returns:
            Module: self
        """
        return self._apply(lambda t: t.double() if t.is_floating_point() else t)

    def half(self: T) -> T:
        r"""Casts all floating point parameters and buffers to ``half`` datatype.

        .. note::
            This method modifies the module in-place.

        Returns:
            Module: self
        """
        return self._apply(lambda t: t.half() if t.is_floating_point() else t)

    def bfloat16(self: T) -> T:
        r"""Casts all floating point parameters and buffers to ``bfloat16`` datatype.

        .. note::
            This method modifies the module in-place.

        Returns:
            Module: self
        """
        return self._apply(lambda t: t.bfloat16() if t.is_floating_point() else t)

    def to_empty(self: T, *, device: Union[str, device]) -> T:
        r"""Moves the parameters and buffers to the specified device without copying storage.

        Args:
            device (:class:`torch.device`): The desired device of the parameters
                and buffers in this module.

        Returns:
            Module: self
        """
        return self._apply(lambda t: torch.empty_like(t, device=device))

    @overload
    def to(self: T, device: Optional[Union[int, device]] = ..., dtype: Optional[Union[dtype, str]] = ...,
           non_blocking: bool = ...) -> T:
        ...

    @overload
    def to(self: T, dtype: Union[dtype, str], non_blocking: bool = ...) -> T:
        ...

    @overload
    def to(self: T, tensor: Tensor, non_blocking: bool = ...) -> T:
        ...

    def to(self, *args, **kwargs):
        r"""Moves and/or casts the parameters and buffers.

        This can be called as

        .. function:: to(device=None, dtype=None, non_blocking=False)
           :noindex:

        .. function:: to(dtype, non_blocking=False)
           :noindex:

        .. function:: to(tensor, non_blocking=False)
           :noindex:

        .. function:: to(memory_format=torch.channels_last)
           :noindex:

        Its signature is similar to :meth:`torch.Tensor.to`, but only accepts
        floating point or complex :attr:`dtype`\ s. In addition, this method will
        only cast the floating point or complex parameters and buffers to :attr:`dtype`
        (if given). The integral parameters and buffers will be moved to
        :attr:`device`, if that is given, but with dtypes unchanged. When
        :attr:`non_blocking` is set, it tries to convert/move asynchronously
        with respect to the host if possible, e.g., moving CPU Tensors with
        pinned memory to CUDA devices.

        See below for examples.

        .. note::
            This method modifies the module in-place.

        Args:
            device (:class:`torch.device`): the desired device of the parameters
                and buffers in this module
            dtype (:class:`torch.dtype`): the desired floating point or complex dtype of
                the parameters and buffers in this module
            tensor (torch.Tensor): Tensor whose dtype and device are the desired
                dtype and device for all parameters and buffers in this module
            memory_format (:class:`torch.memory_format`): the desired memory
                format for 4D parameters and buffers in this module (keyword
                only argument)

        Returns:
            Module: self

        Examples::

            >>> linear = nn.Linear(2, 2)
            >>> linear.weight
            Parameter containing:
            tensor([[ 0.1913, -0.3420],
                    [-0.5113, -0.2325]])
            >>> linear.to(torch.double)
            Linear(in_features=2, out_features=2, bias=True)
            >>> linear.weight
            Parameter containing:
            tensor([[ 0.1913, -0.3420],
                    [-0.5113, -0.2325]], dtype=torch.float64)
            >>> gpu1 = torch.device("cuda:1")
            >>> linear.to(gpu1, dtype=torch.half, non_blocking=True)
            Linear(in_features=2, out_features=2, bias=True)
            >>> linear.weight
            Parameter containing:
            tensor([[ 0.1914, -0.3420],
                    [-0.5112, -0.2324]], dtype=torch.float16, device='cuda:1')
            >>> cpu = torch.device("cpu")
            >>> linear.to(cpu)
            Linear(in_features=2, out_features=2, bias=True)
            >>> linear.weight
            Parameter containing:
            tensor([[ 0.1914, -0.3420],
                    [-0.5112, -0.2324]], dtype=torch.float16)

            >>> linear = nn.Linear(2, 2, bias=None).to(torch.cdouble)
            >>> linear.weight
            Parameter containing:
            tensor([[ 0.3741+0.j,  0.2382+0.j],
                    [ 0.5593+0.j, -0.4443+0.j]], dtype=torch.complex128)
            >>> linear(torch.ones(3, 2, dtype=torch.cdouble))
            tensor([[0.6122+0.j, 0.1150+0.j],
                    [0.6122+0.j, 0.1150+0.j],
                    [0.6122+0.j, 0.1150+0.j]], dtype=torch.complex128)

        """

        device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)

        if dtype is not None:
            if not (dtype.is_floating_point or dtype.is_complex):
                raise TypeError('nn.Module.to only accepts floating point or complex '
                                'dtypes, but got desired dtype={}'.format(dtype))
            if dtype.is_complex:
                warnings.warn(
                    "Complex modules are a new feature under active development whose design may change, "
                    "and some modules might not work as expected when using complex tensors as parameters or buffers. "
                    "Please file an issue at https://github.com/pytorch/pytorch/issues/new?template=bug-report.md "
                    "if a complex module does not work as expected.")

        def convert(t):
            if convert_to_format is not None and t.dim() in (4, 5):
                return t.to(device, dtype if t.is_floating_point() or t.is_complex() else None,
                            non_blocking, memory_format=convert_to_format)
            return t.to(device, dtype if t.is_floating_point() or t.is_complex() else None, non_blocking)

        return self._apply(convert)
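
    # A minimal usage sketch (illustrative only): the common "move first, then
    # build the optimizer" ordering, plus a channels_last conversion. `model`
    # stands in for a user-defined network.
    #
    #     model = model.to(torch.device("cuda:0" if torch.cuda.is_available() else "cpu"))
    #     optimizer = torch.optim.SGD(model.parameters(), lr=0.1)  # after .to()
    #     model = model.to(memory_format=torch.channels_last)      # affects 4D params/buffers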

    def register_backward_hook(
        self, hook: Callable[['Module', _grad_t, _grad_t], Union[None, Tensor]]
    ) -> RemovableHandle:
        r"""Registers a backward hook on the module.

        This function is deprecated in favor of :meth:`~torch.nn.Module.register_full_backward_hook` and
        the behavior of this function will change in future versions.

        Returns:
            :class:`torch.utils.hooks.RemovableHandle`:
                a handle that can be used to remove the added hook by calling
                ``handle.remove()``

        """
        if self._is_full_backward_hook is True:
            raise RuntimeError("Cannot use both regular backward hooks and full backward hooks on a "
                               "single Module. Please use only one of them.")

        self._is_full_backward_hook = False

        handle = hooks.RemovableHandle(self._backward_hooks)
        self._backward_hooks[handle.id] = hook
        return handle

    def register_full_backward_hook(
        self, hook: Callable[['Module', _grad_t, _grad_t], Union[None, Tensor]]
    ) -> RemovableHandle:
        r"""Registers a backward hook on the module.

        The hook will be called every time the gradients with respect to module
        inputs are computed. The hook should have the following signature::

            hook(module, grad_input, grad_output) -> tuple(Tensor) or None

        The :attr:`grad_input` and :attr:`grad_output` are tuples that contain the gradients
        with respect to the inputs and outputs respectively. The hook should
        not modify its arguments, but it can optionally return a new gradient with
        respect to the input that will be used in place of :attr:`grad_input` in
        subsequent computations. :attr:`grad_input` will only correspond to the inputs given
        as positional arguments and all kwarg arguments are ignored. Entries
        in :attr:`grad_input` and :attr:`grad_output` will be ``None`` for all non-Tensor
        arguments.

        For technical reasons, when this hook is applied to a Module, its forward function will
        receive a view of each Tensor passed to the Module. Similarly the caller will receive a view
        of each Tensor returned by the Module's forward function.

        .. warning ::
            Modifying inputs or outputs inplace is not allowed when using backward hooks and
            will raise an error.

        Returns:
            :class:`torch.utils.hooks.RemovableHandle`:
                a handle that can be used to remove the added hook by calling
                ``handle.remove()``

        """
        if self._is_full_backward_hook is False:
            raise RuntimeError("Cannot use both regular backward hooks and full backward hooks on a "
                               "single Module. Please use only one of them.")

        self._is_full_backward_hook = True

        handle = hooks.RemovableHandle(self._backward_hooks)
        self._backward_hooks[handle.id] = hook
        return handle
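
    # A minimal usage sketch (illustrative only): a per-module full backward
    # hook that rescales the gradient flowing into this module's inputs.
    # `halve_grad_input` is a hypothetical helper name.
    #
    #     def halve_grad_input(module, grad_input, grad_output):
    #         return tuple(g * 0.5 if g is not None else None for g in grad_input)
    #
    #     layer = nn.Linear(2, 2)
    #     handle = layer.register_full_backward_hook(halve_grad_input)
    #     ...
    #     handle.remove()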

    def _get_backward_hooks(self):
        r"""Returns the backward hooks for use in the call function.
        It returns two lists, one with the full backward hooks and one with the non-full
        backward hooks.
        """
        full_backward_hooks: List[Callable] = []
        if (_global_is_full_backward_hook is True):
            full_backward_hooks += _global_backward_hooks.values()
        if (self._is_full_backward_hook is True):
            full_backward_hooks += self._backward_hooks.values()

        non_full_backward_hooks: List[Callable] = []
        if (_global_is_full_backward_hook is False):
            non_full_backward_hooks += _global_backward_hooks.values()
        if (self._is_full_backward_hook is False):
            non_full_backward_hooks += self._backward_hooks.values()

        return full_backward_hooks, non_full_backward_hooks

    def _maybe_warn_non_full_backward_hook(self, inputs, result, grad_fn):
        if not isinstance(result, torch.Tensor):
            if not (isinstance(result, tuple) and all([isinstance(r, torch.Tensor) for r in result])):
                warnings.warn("Using non-full backward hooks on a Module that does not return a "
                              "single Tensor or a tuple of Tensors is deprecated and will be removed "
                              "in future versions. This hook will be missing some of the grad_output. "
                              "Please use register_full_backward_hook to get the documented behavior.")
                return
        else:
            result = (result,)

        if not isinstance(inputs, torch.Tensor):
            if not (isinstance(inputs, tuple) and all([isinstance(i, torch.Tensor) for i in inputs])):
                warnings.warn("Using non-full backward hooks on a Module that does not take as input a "
                              "single Tensor or a tuple of Tensors is deprecated and will be removed "
                              "in future versions. This hook will be missing some of the grad_input. "
                              "Please use register_full_backward_hook to get the documented behavior.")
                return
        else:
            inputs = (inputs,)

        # At this point we are sure that inputs and result are tuple of Tensors
        out_grad_fn = {r.grad_fn for r in result if r.grad_fn is not None}
        if len(out_grad_fn) == 0 or (len(out_grad_fn) == 1 and grad_fn not in out_grad_fn):
            warnings.warn("Using a non-full backward hook when outputs are nested in python data structure "
                          "is deprecated and will be removed in future versions. This hook will be missing "
                          "some grad_output.")
        elif len(out_grad_fn) > 1:
            warnings.warn("Using a non-full backward hook when outputs are generated by different autograd Nodes "
                          "is deprecated and will be removed in future versions. This hook will be missing "
                          "some grad_output. Please use register_full_backward_hook to get the documented behavior.")
        else:
            # At this point the grad_output part of the hook will most likely be correct
            inputs_grad_fn = {i.grad_fn for i in inputs if i.grad_fn is not None}

            next_functions = {n[0] for n in grad_fn.next_functions}

            if inputs_grad_fn != next_functions:
                warnings.warn("Using a non-full backward hook when the forward contains multiple autograd Nodes "
                              "is deprecated and will be removed in future versions. This hook will be missing "
                              "some grad_input. Please use register_full_backward_hook to get the documented "
                              "behavior.")

    def register_forward_pre_hook(self, hook: Callable[..., None]) -> RemovableHandle:
 | 
						|
        r"""Registers a forward pre-hook on the module.
 | 
						|
 | 
						|
        The hook will be called every time before :func:`forward` is invoked.
 | 
						|
        It should have the following signature::
 | 
						|
 | 
						|
            hook(module, input) -> None or modified input
 | 
						|
 | 
						|
        The input contains only the positional arguments given to the module.
 | 
						|
        Keyword arguments won't be passed to the hooks and only to the ``forward``.
 | 
						|
        The hook can modify the input. User can either return a tuple or a
 | 
						|
        single modified value in the hook. We will wrap the value into a tuple
 | 
						|
        if a single value is returned(unless that value is already a tuple).
 | 
						|
 | 
						|
        Returns:
 | 
						|
            :class:`torch.utils.hooks.RemovableHandle`:
 | 
						|
                a handle that can be used to remove the added hook by calling
 | 
						|
                ``handle.remove()``
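
        Example (an illustrative sketch; ``net`` stands for any module instance
        and ``scale_first_arg`` is a hypothetical hook, not part of the API)::

            >>> def scale_first_arg(module, input):
            >>>     # ``input`` is the tuple of positional args; return a new tuple
            >>>     return (input[0] * 2.0,) + input[1:]
            >>> handle = net.register_forward_pre_hook(scale_first_arg)
            >>> # ... forward passes now see the scaled input ...
            >>> handle.remove()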
        """
        handle = hooks.RemovableHandle(self._forward_pre_hooks)
        self._forward_pre_hooks[handle.id] = hook
        return handle

    def register_forward_hook(self, hook: Callable[..., None]) -> RemovableHandle:
        r"""Registers a forward hook on the module.

        The hook will be called every time after :func:`forward` has computed an output.
        It should have the following signature::

            hook(module, input, output) -> None or modified output

        The input contains only the positional arguments given to the module.
        Keyword arguments are not passed to the hooks, only to the ``forward``.
        The hook can modify the output. It can also modify the input in place,
        but that will have no effect on the forward pass, since this is called
        after :func:`forward` has run.

        Returns:
            :class:`torch.utils.hooks.RemovableHandle`:
                a handle that can be used to remove the added hook by calling
                ``handle.remove()``
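
        Example (an illustrative sketch; ``net`` stands for any module instance
        and the hook below is hypothetical, assuming a Tensor output)::

            >>> def log_output_norm(module, input, output):
            >>>     print(module.__class__.__name__, output.norm().item())
            >>> handle = net.register_forward_hook(log_output_norm)
            >>> # ... run forward passes, then detach the hook ...
            >>> handle.remove()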
        """
        handle = hooks.RemovableHandle(self._forward_hooks)
        self._forward_hooks[handle.id] = hook
        return handle

    def _slow_forward(self, *input, **kwargs):
        tracing_state = torch._C._get_tracing_state()
        if not tracing_state or isinstance(self.forward, torch._C.ScriptMethod):
            return self.forward(*input, **kwargs)
        recording_scopes = torch.jit._trace._trace_module_map is not None
        if recording_scopes:
            # type ignore was added because at this point one knows that
            # torch.jit._trace._trace_module_map is not Optional and has type Dict[Any, Any]
            name = torch.jit._trace._trace_module_map[self] if self in torch.jit._trace._trace_module_map else None  # type: ignore[index, operator] # noqa: B950
            if name:
                tracing_state.push_scope(name)
            else:
                recording_scopes = False
        try:
            result = self.forward(*input, **kwargs)
        finally:
            if recording_scopes:
                tracing_state.pop_scope()
        return result

    def _call_impl(self, *input, **kwargs):
        forward_call = (self._slow_forward if torch._C._get_tracing_state() else self.forward)
        # If we don't have any hooks, we want to skip the rest of the logic in
        # this function, and just call forward.
        if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
                or _global_forward_hooks or _global_forward_pre_hooks):
            return forward_call(*input, **kwargs)
        # Do not call functions when jit is used
        full_backward_hooks, non_full_backward_hooks = [], []
        if self._backward_hooks or _global_backward_hooks:
            full_backward_hooks, non_full_backward_hooks = self._get_backward_hooks()
        if _global_forward_pre_hooks or self._forward_pre_hooks:
            for hook in (*_global_forward_pre_hooks.values(), *self._forward_pre_hooks.values()):
                result = hook(self, input)
                if result is not None:
                    if not isinstance(result, tuple):
                        result = (result,)
                    input = result

        bw_hook = None
        if full_backward_hooks:
            bw_hook = hooks.BackwardHook(self, full_backward_hooks)
            input = bw_hook.setup_input_hook(input)

        result = forward_call(*input, **kwargs)
        if _global_forward_hooks or self._forward_hooks:
            for hook in (*_global_forward_hooks.values(), *self._forward_hooks.values()):
                hook_result = hook(self, input, result)
                if hook_result is not None:
                    result = hook_result

        if bw_hook:
            result = bw_hook.setup_output_hook(result)

        # Handle the non-full backward hooks
        if non_full_backward_hooks:
            var = result
            while not isinstance(var, torch.Tensor):
                if isinstance(var, dict):
                    var = next((v for v in var.values() if isinstance(v, torch.Tensor)))
                else:
                    var = var[0]
            grad_fn = var.grad_fn
            if grad_fn is not None:
                for hook in non_full_backward_hooks:
                    wrapper = functools.partial(hook, self)
                    functools.update_wrapper(wrapper, hook)
                    grad_fn.register_hook(wrapper)
                self._maybe_warn_non_full_backward_hook(input, result, grad_fn)

        return result

    __call__ : Callable[..., Any] = _call_impl

    def __setstate__(self, state):
        self.__dict__.update(state)
        # Support loading old checkpoints that don't have the following attrs:
        if '_forward_pre_hooks' not in self.__dict__:
            self._forward_pre_hooks = OrderedDict()
        if '_state_dict_hooks' not in self.__dict__:
            self._state_dict_hooks = OrderedDict()
        if '_load_state_dict_pre_hooks' not in self.__dict__:
            self._load_state_dict_pre_hooks = OrderedDict()
        if '_non_persistent_buffers_set' not in self.__dict__:
            self._non_persistent_buffers_set = set()
        if '_is_full_backward_hook' not in self.__dict__:
            self._is_full_backward_hook = None

    def __getattr__(self, name: str) -> Union[Tensor, 'Module']:
        if '_parameters' in self.__dict__:
            _parameters = self.__dict__['_parameters']
            if name in _parameters:
                return _parameters[name]
        if '_buffers' in self.__dict__:
            _buffers = self.__dict__['_buffers']
            if name in _buffers:
                return _buffers[name]
        if '_modules' in self.__dict__:
            modules = self.__dict__['_modules']
            if name in modules:
                return modules[name]
        raise AttributeError("'{}' object has no attribute '{}'".format(
            type(self).__name__, name))

    def __setattr__(self, name: str, value: Union[Tensor, 'Module']) -> None:
        def remove_from(*dicts_or_sets):
            for d in dicts_or_sets:
                if name in d:
                    if isinstance(d, dict):
                        del d[name]
                    else:
                        d.discard(name)

        params = self.__dict__.get('_parameters')
        if isinstance(value, Parameter):
            if params is None:
                raise AttributeError(
                    "cannot assign parameters before Module.__init__() call")
            remove_from(self.__dict__, self._buffers, self._modules, self._non_persistent_buffers_set)
            self.register_parameter(name, value)
        elif params is not None and name in params:
            if value is not None:
                raise TypeError("cannot assign '{}' as parameter '{}' "
                                "(torch.nn.Parameter or None expected)"
                                .format(torch.typename(value), name))
            self.register_parameter(name, value)
        else:
            modules = self.__dict__.get('_modules')
            if isinstance(value, Module):
                if modules is None:
                    raise AttributeError(
                        "cannot assign module before Module.__init__() call")
                remove_from(self.__dict__, self._parameters, self._buffers, self._non_persistent_buffers_set)
                modules[name] = value
            elif modules is not None and name in modules:
                if value is not None:
                    raise TypeError("cannot assign '{}' as child module '{}' "
                                    "(torch.nn.Module or None expected)"
                                    .format(torch.typename(value), name))
                modules[name] = value
            else:
                buffers = self.__dict__.get('_buffers')
                if buffers is not None and name in buffers:
                    if value is not None and not isinstance(value, torch.Tensor):
                        raise TypeError("cannot assign '{}' as buffer '{}' "
                                        "(torch.Tensor or None expected)"
                                        .format(torch.typename(value), name))
                    buffers[name] = value
                else:
                    object.__setattr__(self, name, value)

    def __delattr__(self, name):
        if name in self._parameters:
            del self._parameters[name]
        elif name in self._buffers:
            del self._buffers[name]
            self._non_persistent_buffers_set.discard(name)
        elif name in self._modules:
            del self._modules[name]
        else:
            object.__delattr__(self, name)

    def _register_state_dict_hook(self, hook):
        r"""These hooks will be called with arguments: `self`, `state_dict`,
        `prefix`, `local_metadata`, after the `state_dict` of `self` is set.
        Note that only parameters and buffers of `self` or its children are
        guaranteed to exist in `state_dict`. The hooks may modify `state_dict`
        inplace or return a new one.
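
        Example (an illustrative sketch; ``net`` and the buffer name below are
        hypothetical)::

            >>> def drop_scratch_buffer(module, state_dict, prefix, local_metadata):
            >>>     # remove an entry we do not want serialized
            >>>     state_dict.pop(prefix + 'scratch', None)
            >>> handle = net._register_state_dict_hook(drop_scratch_buffer)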
        """
        handle = hooks.RemovableHandle(self._state_dict_hooks)
        self._state_dict_hooks[handle.id] = hook
        return handle

    def _save_to_state_dict(self, destination, prefix, keep_vars):
        r"""Saves module state to the `destination` dictionary, containing the
        state of the module, but not its descendants. This is called on every
        submodule in :meth:`~torch.nn.Module.state_dict`.

        In rare cases, subclasses can achieve class-specific behavior by
        overriding this method with custom logic.

        Args:
            destination (dict): a dict where state will be stored
            prefix (str): the prefix for parameters and buffers used in this
                module
            keep_vars (bool): if ``False``, parameters and buffers are
                detached from autograd before being stored
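
        Example (a minimal sketch of an override; ``MyModule`` and the extra
        entry are illustrative)::

            >>> class MyModule(nn.Module):
            >>>     def _save_to_state_dict(self, destination, prefix, keep_vars):
            >>>         super()._save_to_state_dict(destination, prefix, keep_vars)
            >>>         destination[prefix + 'format_version'] = torch.tensor(1)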
        """
        for name, param in self._parameters.items():
            if param is not None:
                destination[prefix + name] = param if keep_vars else param.detach()
        for name, buf in self._buffers.items():
            if buf is not None and name not in self._non_persistent_buffers_set:
                destination[prefix + name] = buf if keep_vars else buf.detach()
        extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX
        if getattr(self.__class__, "get_extra_state", Module.get_extra_state) is not Module.get_extra_state:
            destination[extra_state_key] = self.get_extra_state()

    # The user can pass an optional arbitrary mappable object to `state_dict`, in which case `state_dict` returns
    # that same object. But if they pass nothing, an `OrderedDict` is created and returned.
    T_destination = TypeVar('T_destination', bound=Mapping[str, Tensor])

    @overload
    def state_dict(self, destination: T_destination, prefix: str = ..., keep_vars: bool = ...) -> T_destination:
        ...

    # TODO: Remove string escape once Python-3.6 no longer supported
    # See https://github.com/python/mypy/issues/6904#issuecomment-496207426
    @overload
    def state_dict(self, prefix: str = ..., keep_vars: bool = ...) -> 'OrderedDict[str, Tensor]':
        ...

    def state_dict(self, destination=None, prefix='', keep_vars=False):
        r"""Returns a dictionary containing a whole state of the module.

        Both parameters and persistent buffers (e.g. running averages) are
        included. Keys are corresponding parameter and buffer names.
        Parameters and buffers set to ``None`` are not included.

        Returns:
            dict:
                a dictionary containing a whole state of the module

        Example::

            >>> module.state_dict().keys()
            ['bias', 'weight']

        """
        if destination is None:
            destination = OrderedDict()
            destination._metadata = OrderedDict()
        destination._metadata[prefix[:-1]] = local_metadata = dict(version=self._version)
        self._save_to_state_dict(destination, prefix, keep_vars)
        for name, module in self._modules.items():
            if module is not None:
                module.state_dict(destination, prefix + name + '.', keep_vars=keep_vars)
        for hook in self._state_dict_hooks.values():
            hook_result = hook(self, destination, prefix, local_metadata)
            if hook_result is not None:
                destination = hook_result
        return destination

    def _register_load_state_dict_pre_hook(self, hook, with_module=False):
        r"""These hooks will be called with arguments: `state_dict`, `prefix`,
        `local_metadata`, `strict`, `missing_keys`, `unexpected_keys`,
        `error_msgs`, before loading `state_dict` into `self`. These arguments
        are exactly the same as those of `_load_from_state_dict`.

        If ``with_module`` is ``True``, then the first argument to the hook is
        an instance of the module.

        Arguments:
            hook (Callable): Callable hook that will be invoked before
                loading the state dict.
            with_module (bool, optional): Whether or not to pass the module
                instance to the hook as the first parameter.
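
        Example (an illustrative sketch; ``net`` and the key names below are
        hypothetical)::

            >>> def rename_legacy_keys(state_dict, prefix, local_metadata, strict,
            >>>                        missing_keys, unexpected_keys, error_msgs):
            >>>     old_key = prefix + 'gamma'
            >>>     if old_key in state_dict:
            >>>         state_dict[prefix + 'weight'] = state_dict.pop(old_key)
            >>> handle = net._register_load_state_dict_pre_hook(rename_legacy_keys)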
        """
        handle = hooks.RemovableHandle(self._load_state_dict_pre_hooks)
        if with_module:
            hook = functools.partial(hook, self)
        self._load_state_dict_pre_hooks[handle.id] = hook
        return handle

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        r"""Copies parameters and buffers from :attr:`state_dict` into only
        this module, but not its descendants. This is called on every submodule
        in :meth:`~torch.nn.Module.load_state_dict`. Metadata saved for this
        module in input :attr:`state_dict` is provided as :attr:`local_metadata`.
        For state dicts without metadata, :attr:`local_metadata` is empty.
        Subclasses can achieve class-specific backward compatible loading using
        the version number at `local_metadata.get("version", None)`.

        .. note::
            :attr:`state_dict` is not the same object as the input
            :attr:`state_dict` to :meth:`~torch.nn.Module.load_state_dict`. So
            it can be modified.

        Args:
            state_dict (dict): a dict containing parameters and
                persistent buffers.
            prefix (str): the prefix for parameters and buffers used in this
                module
            local_metadata (dict): a dict containing the metadata for this module.
            strict (bool): whether to strictly enforce that the keys in
                :attr:`state_dict` with :attr:`prefix` match the names of
                parameters and buffers in this module
            missing_keys (list of str): if ``strict=True``, add missing keys to
                this list
            unexpected_keys (list of str): if ``strict=True``, add unexpected
                keys to this list
            error_msgs (list of str): error messages should be added to this
                list, and will be reported together in
                :meth:`~torch.nn.Module.load_state_dict`
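
        Example (a minimal sketch of a backward-compatible override; the class,
        key names and version handling below are illustrative)::

            >>> class MyModule(nn.Module):
            >>>     def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
            >>>                               missing_keys, unexpected_keys, error_msgs):
            >>>         if local_metadata.get("version", None) is None:
            >>>             # migrate a key layout used by older checkpoints
            >>>             old_key = prefix + 'gamma'
            >>>             if old_key in state_dict:
            >>>                 state_dict[prefix + 'weight'] = state_dict.pop(old_key)
            >>>         super()._load_from_state_dict(state_dict, prefix, local_metadata, strict,
            >>>                                       missing_keys, unexpected_keys, error_msgs)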
        """
        for hook in self._load_state_dict_pre_hooks.values():
            hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)

        persistent_buffers = {k: v for k, v in self._buffers.items() if k not in self._non_persistent_buffers_set}
        local_name_params = itertools.chain(self._parameters.items(), persistent_buffers.items())
        local_state = {k: v for k, v in local_name_params if v is not None}

        for name, param in local_state.items():
            key = prefix + name
            if key in state_dict:
                input_param = state_dict[key]
                # This is used to avoid copying uninitialized parameters into
                # non-lazy modules, since they don't have the hook to do the checks;
                # in such a case, it would error out when accessing the .shape attribute.
                is_param_lazy = torch.nn.parameter.is_lazy(param)
                # Backward compatibility: loading 1-dim tensor from 0.3.* to version 0.4+
                if not is_param_lazy and len(param.shape) == 0 and len(input_param.shape) == 1:
                    input_param = input_param[0]

                if not is_param_lazy and input_param.shape != param.shape:
                    # local shape should match the one in checkpoint
                    error_msgs.append('size mismatch for {}: copying a param with shape {} from checkpoint, '
                                      'the shape in current model is {}.'
                                      .format(key, input_param.shape, param.shape))
                    continue
                try:
                    with torch.no_grad():
                        param.copy_(input_param)
                except Exception as ex:
                    error_msgs.append('While copying the parameter named "{}", '
                                      'whose dimensions in the model are {} and '
                                      'whose dimensions in the checkpoint are {}, '
                                      'an exception occurred : {}.'
                                      .format(key, param.size(), input_param.size(), ex.args))
            elif strict:
                missing_keys.append(key)

        extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX
        if getattr(self.__class__, "set_extra_state", Module.set_extra_state) is not Module.set_extra_state:
            if extra_state_key in state_dict:
                self.set_extra_state(state_dict[extra_state_key])
            elif strict:
                missing_keys.append(extra_state_key)
        elif strict and (extra_state_key in state_dict):
            unexpected_keys.append(extra_state_key)

        if strict:
            for key in state_dict.keys():
                if key.startswith(prefix) and key != extra_state_key:
                    input_name = key[len(prefix):]
                    input_name = input_name.split('.', 1)[0]  # get the name of param/buffer/child
                    if input_name not in self._modules and input_name not in local_state:
                        unexpected_keys.append(key)

    def load_state_dict(self, state_dict: 'OrderedDict[str, Tensor]',
                        strict: bool = True):
        r"""Copies parameters and buffers from :attr:`state_dict` into
        this module and its descendants. If :attr:`strict` is ``True``, then
        the keys of :attr:`state_dict` must exactly match the keys returned
        by this module's :meth:`~torch.nn.Module.state_dict` function.

        Args:
            state_dict (dict): a dict containing parameters and
                persistent buffers.
            strict (bool, optional): whether to strictly enforce that the keys
                in :attr:`state_dict` match the keys returned by this module's
                :meth:`~torch.nn.Module.state_dict` function. Default: ``True``

        Returns:
            ``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields:
                * **missing_keys** is a list of str containing the missing keys
                * **unexpected_keys** is a list of str containing the unexpected keys

        Note:
            If a parameter or buffer is registered as ``None`` and its corresponding key
            exists in :attr:`state_dict`, :meth:`load_state_dict` will raise a
            ``RuntimeError``.
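
        Example (an illustrative sketch; ``net`` and ``'checkpoint.pt'`` are
        placeholders for your own module and checkpoint file)::

            >>> state = torch.load('checkpoint.pt')
            >>> incompatible = net.load_state_dict(state, strict=False)
            >>> print(incompatible.missing_keys, incompatible.unexpected_keys)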
        """
        missing_keys: List[str] = []
        unexpected_keys: List[str] = []
        error_msgs: List[str] = []

        # copy state_dict so _load_from_state_dict can modify it
        metadata = getattr(state_dict, '_metadata', None)
        state_dict = state_dict.copy()
        if metadata is not None:
            # mypy isn't aware that "_metadata" exists in state_dict
            state_dict._metadata = metadata  # type: ignore[attr-defined]

        def load(module, prefix=''):
            local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
            module._load_from_state_dict(
                state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
            for name, child in module._modules.items():
                if child is not None:
                    load(child, prefix + name + '.')

        load(self)
        del load

        if strict:
            if len(unexpected_keys) > 0:
                error_msgs.insert(
                    0, 'Unexpected key(s) in state_dict: {}. '.format(
                        ', '.join('"{}"'.format(k) for k in unexpected_keys)))
            if len(missing_keys) > 0:
                error_msgs.insert(
                    0, 'Missing key(s) in state_dict: {}. '.format(
                        ', '.join('"{}"'.format(k) for k in missing_keys)))

        if len(error_msgs) > 0:
            raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
                               self.__class__.__name__, "\n\t".join(error_msgs)))
        return _IncompatibleKeys(missing_keys, unexpected_keys)

    def _named_members(self, get_members_fn, prefix='', recurse=True):
        r"""Helper method for yielding various names + members of modules."""
        memo = set()
        modules = self.named_modules(prefix=prefix) if recurse else [(prefix, self)]
        for module_prefix, module in modules:
            members = get_members_fn(module)
            for k, v in members:
                if v is None or v in memo:
                    continue
                memo.add(v)
                name = module_prefix + ('.' if module_prefix else '') + k
                yield name, v

    def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
        r"""Returns an iterator over module parameters.

        This is typically passed to an optimizer.

        Args:
            recurse (bool): if True, then yields parameters of this module
                and all submodules. Otherwise, yields only parameters that
                are direct members of this module.

        Yields:
            Parameter: module parameter

        Example::

            >>> for param in model.parameters():
            >>>     print(type(param), param.size())
            <class 'torch.Tensor'> (20L,)
            <class 'torch.Tensor'> (20L, 1L, 5L, 5L)

        """
        for name, param in self.named_parameters(recurse=recurse):
            yield param

    def named_parameters(self, prefix: str = '', recurse: bool = True) -> Iterator[Tuple[str, Parameter]]:
        r"""Returns an iterator over module parameters, yielding both the
        name of the parameter as well as the parameter itself.

        Args:
            prefix (str): prefix to prepend to all parameter names.
            recurse (bool): if True, then yields parameters of this module
                and all submodules. Otherwise, yields only parameters that
                are direct members of this module.

        Yields:
            (string, Parameter): Tuple containing the name and parameter

        Example::

            >>> for name, param in self.named_parameters():
            >>>    if name in ['bias']:
            >>>        print(param.size())

        """
        gen = self._named_members(
            lambda module: module._parameters.items(),
            prefix=prefix, recurse=recurse)
        for elem in gen:
            yield elem

    def buffers(self, recurse: bool = True) -> Iterator[Tensor]:
        r"""Returns an iterator over module buffers.

        Args:
            recurse (bool): if True, then yields buffers of this module
                and all submodules. Otherwise, yields only buffers that
                are direct members of this module.

        Yields:
            torch.Tensor: module buffer

        Example::

            >>> for buf in model.buffers():
            >>>     print(type(buf), buf.size())
            <class 'torch.Tensor'> (20L,)
            <class 'torch.Tensor'> (20L, 1L, 5L, 5L)

        """
        for _, buf in self.named_buffers(recurse=recurse):
            yield buf

    def named_buffers(self, prefix: str = '', recurse: bool = True) -> Iterator[Tuple[str, Tensor]]:
        r"""Returns an iterator over module buffers, yielding both the
        name of the buffer as well as the buffer itself.

        Args:
            prefix (str): prefix to prepend to all buffer names.
            recurse (bool): if True, then yields buffers of this module
                and all submodules. Otherwise, yields only buffers that
                are direct members of this module.

        Yields:
            (string, torch.Tensor): Tuple containing the name and buffer

        Example::

            >>> for name, buf in self.named_buffers():
            >>>    if name in ['running_var']:
            >>>        print(buf.size())

        """
        gen = self._named_members(
            lambda module: module._buffers.items(),
            prefix=prefix, recurse=recurse)
        for elem in gen:
            yield elem

    def children(self) -> Iterator['Module']:
        r"""Returns an iterator over immediate children modules.

        Yields:
            Module: a child module
        """
        for name, module in self.named_children():
            yield module

    def named_children(self) -> Iterator[Tuple[str, 'Module']]:
        r"""Returns an iterator over immediate children modules, yielding both
        the name of the module as well as the module itself.

        Yields:
            (string, Module): Tuple containing a name and child module

        Example::

            >>> for name, module in model.named_children():
            >>>     if name in ['conv4', 'conv5']:
            >>>         print(module)

        """
        memo = set()
        for name, module in self._modules.items():
            if module is not None and module not in memo:
                memo.add(module)
                yield name, module

    def modules(self) -> Iterator['Module']:
        r"""Returns an iterator over all modules in the network.

        Yields:
            Module: a module in the network

        Note:
            Duplicate modules are returned only once. In the following
            example, ``l`` will be returned only once.

        Example::

            >>> l = nn.Linear(2, 2)
            >>> net = nn.Sequential(l, l)
            >>> for idx, m in enumerate(net.modules()):
                    print(idx, '->', m)

            0 -> Sequential(
              (0): Linear(in_features=2, out_features=2, bias=True)
              (1): Linear(in_features=2, out_features=2, bias=True)
            )
            1 -> Linear(in_features=2, out_features=2, bias=True)

        """
        for _, module in self.named_modules():
            yield module

    def named_modules(self, memo: Optional[Set['Module']] = None, prefix: str = '', remove_duplicate: bool = True):
        r"""Returns an iterator over all modules in the network, yielding
        both the name of the module as well as the module itself.

        Args:
            memo: a memo to store the set of modules already added to the result
            prefix: a prefix that will be added to the name of the module
            remove_duplicate: whether to remove the duplicated module instances in the result
                or not

        Yields:
            (string, Module): Tuple of name and module

        Note:
            Duplicate modules are returned only once. In the following
            example, ``l`` will be returned only once.

        Example::

            >>> l = nn.Linear(2, 2)
            >>> net = nn.Sequential(l, l)
            >>> for idx, m in enumerate(net.named_modules()):
                    print(idx, '->', m)

            0 -> ('', Sequential(
              (0): Linear(in_features=2, out_features=2, bias=True)
              (1): Linear(in_features=2, out_features=2, bias=True)
            ))
            1 -> ('0', Linear(in_features=2, out_features=2, bias=True))

        """

        if memo is None:
            memo = set()
        if self not in memo:
            if remove_duplicate:
                memo.add(self)
            yield prefix, self
            for name, module in self._modules.items():
                if module is None:
                    continue
                submodule_prefix = prefix + ('.' if prefix else '') + name
                for m in module.named_modules(memo, submodule_prefix, remove_duplicate):
                    yield m

    def train(self: T, mode: bool = True) -> T:
        r"""Sets the module in training mode.

        This has an effect only on certain modules. See the documentation of
        particular modules for details of their behavior in training/evaluation
        mode, i.e. whether they are affected, e.g. :class:`Dropout`,
        :class:`BatchNorm`, etc.

        Args:
            mode (bool): whether to set training mode (``True``) or evaluation
                         mode (``False``). Default: ``True``.

        Returns:
            Module: self
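
        Example (an illustrative sketch; ``net`` is any module)::

            >>> net.train()  # enable training-time behavior such as dropout
            >>> net.training
            True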
        """
        if not isinstance(mode, bool):
            raise ValueError("training mode is expected to be boolean")
        self.training = mode
        for module in self.children():
            module.train(mode)
        return self

    def eval(self: T) -> T:
        r"""Sets the module in evaluation mode.

        This has an effect only on certain modules. See the documentation of
        particular modules for details of their behavior in training/evaluation
        mode, i.e. whether they are affected, e.g. :class:`Dropout`,
        :class:`BatchNorm`, etc.

        This is equivalent to :meth:`self.train(False) <torch.nn.Module.train>`.

        See :ref:`locally-disable-grad-doc` for a comparison between
        `.eval()` and several similar mechanisms that may be confused with it.

        Returns:
            Module: self
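
        Example (an illustrative sketch; ``net`` is any module)::

            >>> net.eval()  # e.g. before validation or inference
            >>> net.training
            False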
        """
        return self.train(False)

    def requires_grad_(self: T, requires_grad: bool = True) -> T:
        r"""Change whether autograd should record operations on parameters in
        this module.

        This method sets the parameters' :attr:`requires_grad` attributes
        in-place.

        This method is helpful for freezing part of the module for finetuning
        or training parts of a model individually (e.g., GAN training).

        See :ref:`locally-disable-grad-doc` for a comparison between
        `.requires_grad_()` and several similar mechanisms that may be confused with it.

        Args:
            requires_grad (bool): whether autograd should record operations on
                                  parameters in this module. Default: ``True``.

        Returns:
            Module: self
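
        Example (an illustrative sketch; ``net.backbone`` is a hypothetical
        submodule being frozen for finetuning)::

            >>> net.backbone.requires_grad_(False)
            >>> all(not p.requires_grad for p in net.backbone.parameters())
            True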
        """
        for p in self.parameters():
            p.requires_grad_(requires_grad)
        return self

    def zero_grad(self, set_to_none: bool = False) -> None:
        r"""Sets gradients of all model parameters to zero. See similar function
        under :class:`torch.optim.Optimizer` for more context.

        Args:
            set_to_none (bool): instead of setting to zero, set the grads to None.
                See :meth:`torch.optim.Optimizer.zero_grad` for details.
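
        Example (an illustrative sketch of a training step; ``net``, ``optimizer``
        and ``loss`` are assumed to exist)::

            >>> loss.backward()
            >>> optimizer.step()
            >>> net.zero_grad(set_to_none=True)  # free gradient memory between steps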
        """
        if getattr(self, '_is_replica', False):
            warnings.warn(
                "Calling .zero_grad() from a module created with nn.DataParallel() has no effect. "
                "The parameters are copied (in a differentiable manner) from the original module. "
                "This means they are not leaf nodes in autograd and so don't accumulate gradients. "
                "If you need gradients in your forward method, consider using autograd.grad instead.")

        for p in self.parameters():
            if p.grad is not None:
                if set_to_none:
                    p.grad = None
                else:
                    if p.grad.grad_fn is not None:
                        p.grad.detach_()
                    else:
                        p.grad.requires_grad_(False)
                    p.grad.zero_()

    def share_memory(self: T) -> T:
        r"""See :meth:`torch.Tensor.share_memory_`"""
        return self._apply(lambda t: t.share_memory_())

    def _get_name(self):
        return self.__class__.__name__

    def extra_repr(self) -> str:
        r"""Set the extra representation of the module

        To print customized extra information, you should re-implement
        this method in your own modules. Both single-line and multi-line
        strings are acceptable.
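
        Example (a minimal sketch of an override; ``MyLinear`` and its
        attributes are illustrative)::

            >>> class MyLinear(nn.Module):
            >>>     def extra_repr(self):
            >>>         return 'in_features={}, out_features={}'.format(
            >>>             self.in_features, self.out_features)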
        """
        return ''

    def __repr__(self):
        # We treat the extra repr like the sub-module, one item per line
        extra_lines = []
        extra_repr = self.extra_repr()
        # empty string will be split into list ['']
        if extra_repr:
            extra_lines = extra_repr.split('\n')
        child_lines = []
        for key, module in self._modules.items():
            mod_str = repr(module)
            mod_str = _addindent(mod_str, 2)
            child_lines.append('(' + key + '): ' + mod_str)
        lines = extra_lines + child_lines

        main_str = self._get_name() + '('
        if lines:
            # simple one-liner info, which most builtin Modules will use
            if len(extra_lines) == 1 and not child_lines:
                main_str += extra_lines[0]
            else:
                main_str += '\n  ' + '\n  '.join(lines) + '\n'

        main_str += ')'
        return main_str

    def __dir__(self):
        module_attrs = dir(self.__class__)
        attrs = list(self.__dict__.keys())
        parameters = list(self._parameters.keys())
        modules = list(self._modules.keys())
        buffers = list(self._buffers.keys())
        keys = module_attrs + attrs + parameters + modules + buffers

        # Eliminate attrs that are not legal Python variable names
        keys = [key for key in keys if not key[0].isdigit()]

        return sorted(keys)

    def _replicate_for_data_parallel(self):
        replica = self.__new__(type(self))
        replica.__dict__ = self.__dict__.copy()

        # replicas do not have parameters themselves, the replicas reference the original
        # module.
        replica._parameters = OrderedDict()
        replica._buffers = replica._buffers.copy()
        replica._modules = replica._modules.copy()
        replica._is_replica = True

        return replica