mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Summary: Original PR: https://github.com/pytorch/pytorch/pull/37419 cc mattip suo Pull Request resolved: https://github.com/pytorch/pytorch/pull/37778 Differential Revision: D21385774 Pulled By: ezyang fbshipit-source-id: 5de532faab8bae132736b6b5189e0ee2ac9935be
1041 lines
47 KiB
Python
1041 lines
47 KiB
Python
import math
|
|
import warnings
|
|
import numbers
|
|
from typing import Tuple, Optional, overload
|
|
|
|
import torch
|
|
from torch import Tensor
|
|
from .module import Module
|
|
from ..parameter import Parameter
|
|
from ..utils.rnn import PackedSequence
|
|
from .. import init
|
|
from ... import _VF
|
|
|
|
_rnn_impls = {
|
|
'RNN_TANH': _VF.rnn_tanh,
|
|
'RNN_RELU': _VF.rnn_relu,
|
|
}
|
|
|
|
|
|
def apply_permutation(tensor, permutation, dim=1):
    # type: (Tensor, Tensor, int) -> Tensor
    """Gather the slices of ``tensor`` along ``dim`` in the order given by ``permutation``.

    The RNN modules use this (via ``permute_hidden``) to move hidden states
    between the caller's batch order and a packed sequence's sorted order;
    the hidden state's batch axis is dimension 1, hence the default.
    """
    return torch.index_select(tensor, dim, permutation)
|
|
|
|
|
|
class RNNBase(Module):
    """Common base class for :class:`RNN`, :class:`LSTM` and :class:`GRU`.

    For every layer and direction it registers ``weight_ih_l{k}`` and
    ``weight_hh_l{k}`` (plus ``bias_ih_l{k}`` and ``bias_hh_l{k}`` when
    ``bias`` is ``True``; reverse-direction names get a ``_reverse`` suffix).
    It keeps those parameters in the flat list ``_flat_weights`` that is
    passed to the ``_VF`` kernels. It also provides the shared argument
    checks and a generic :meth:`forward` that dispatches through
    ``_rnn_impls``.

    Args:
        mode: ``'LSTM'``, ``'GRU'``, ``'RNN_TANH'`` or ``'RNN_RELU'``. Selects
            the gate size: 4, 3, 1 and 1 times ``hidden_size`` respectively.
        input_size: number of features in the input.
        hidden_size: number of features in the hidden state; must be > 0.
        num_layers: number of stacked recurrent layers; must be > 0. Default: 1
        bias: if ``False``, the bias parameters are not registered. Default: ``True``
        batch_first: if ``True``, unpacked input/output are laid out as
            ``(batch, seq, feature)``. Default: ``False``
        dropout: dropout probability in ``[0, 1]`` (stored as a float). A
            warning is issued if it is non-zero while ``num_layers == 1``.
            Default: 0
        bidirectional: if ``True``, a reverse direction is added to every
            layer. Default: ``False``

    Raises:
        ValueError: if ``dropout`` is not a non-boolean number in ``[0, 1]``,
            if ``hidden_size`` or ``num_layers`` is not positive, or if
            ``mode`` is unrecognized.
    """

    __constants__ = ['mode', 'input_size', 'hidden_size', 'num_layers', 'bias',
                     'batch_first', 'dropout', 'bidirectional']

    def __init__(self, mode, input_size, hidden_size,
                 num_layers=1, bias=True, batch_first=False,
                 dropout=0., bidirectional=False):
        super(RNNBase, self).__init__()

        # Validate ``dropout`` *before* converting it with float(). Otherwise
        # inputs such as None or "abc" fail inside float() with an unrelated
        # TypeError/ValueError instead of this descriptive message.
        # bool counts as a numbers.Number, so it is rejected explicitly.
        if not isinstance(dropout, numbers.Number) or not 0 <= dropout <= 1 or \
                isinstance(dropout, bool):
            raise ValueError("dropout should be a number in range [0, 1] "
                             "representing the probability of an element being "
                             "zeroed")
        # Without these checks, num_layers <= 0 crashed with an IndexError in
        # flatten_parameters() (empty weight list), and hidden_size == 0 with a
        # ZeroDivisionError in reset_parameters().
        if hidden_size <= 0:
            raise ValueError("hidden_size must be greater than zero, got {}".format(hidden_size))
        if num_layers <= 0:
            raise ValueError("num_layers must be greater than zero, got {}".format(num_layers))

        self.mode = mode
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.bias = bias
        self.batch_first = batch_first
        self.dropout = float(dropout)
        self.bidirectional = bidirectional
        num_directions = 2 if bidirectional else 1

        if dropout > 0 and num_layers == 1:
            warnings.warn("dropout option adds dropout after all but last "
                          "recurrent layer, so non-zero dropout expects "
                          "num_layers greater than 1, but got dropout={} and "
                          "num_layers={}".format(dropout, num_layers))

        if mode == 'LSTM':
            gate_size = 4 * hidden_size
        elif mode == 'GRU':
            gate_size = 3 * hidden_size
        elif mode == 'RNN_TANH':
            gate_size = hidden_size
        elif mode == 'RNN_RELU':
            gate_size = hidden_size
        else:
            raise ValueError("Unrecognized RNN mode: " + mode)

        self._flat_weights_names = []
        self._all_weights = []
        for layer in range(num_layers):
            for direction in range(num_directions):
                # Layers after the first consume the (possibly concatenated
                # bidirectional) output of the previous layer.
                layer_input_size = input_size if layer == 0 else hidden_size * num_directions

                w_ih = Parameter(torch.Tensor(gate_size, layer_input_size))
                w_hh = Parameter(torch.Tensor(gate_size, hidden_size))
                b_ih = Parameter(torch.Tensor(gate_size))
                # Second bias vector included for CuDNN compatibility. Only one
                # bias vector is needed in standard definition.
                b_hh = Parameter(torch.Tensor(gate_size))
                layer_params = (w_ih, w_hh, b_ih, b_hh)

                suffix = '_reverse' if direction == 1 else ''
                param_names = ['weight_ih_l{}{}', 'weight_hh_l{}{}']
                if bias:
                    param_names += ['bias_ih_l{}{}', 'bias_hh_l{}{}']
                param_names = [x.format(layer, suffix) for x in param_names]

                # zip() stops at len(param_names), so the bias tensors are
                # simply dropped when bias=False.
                for name, param in zip(param_names, layer_params):
                    setattr(self, name, param)
                self._flat_weights_names.extend(param_names)
                self._all_weights.append(param_names)

        self._update_flat_weights()
        self.flatten_parameters()
        self.reset_parameters()

    def _update_flat_weights(self):
        # Rebuild _flat_weights from _flat_weights_names, in the same order.
        # Names that are not currently attributes of the module map to None.
        self._flat_weights = [getattr(self, wn, None) for wn in self._flat_weights_names]

    def __setattr__(self, attr, value):
        if hasattr(self, "_flat_weights_names") and attr in self._flat_weights_names:
            # keep self._flat_weights up to date if you do self.weight = ...
            idx = self._flat_weights_names.index(attr)
            self._flat_weights[idx] = value
        super(RNNBase, self).__setattr__(attr, value)

    def flatten_parameters(self):
        """Resets parameter data pointer so that they can use faster code paths.

        Right now, this works only if the module is on the GPU and cuDNN is enabled.
        Otherwise, it's a no-op.
        """
        # Short-circuits if _flat_weights is empty or only partially instantiated
        if not self._flat_weights or len(self._flat_weights) != len(self._flat_weights_names):
            return

        # Short-circuits if any entry is not a tensor (a name that is not
        # currently an attribute is stored as None).
        for w in self._flat_weights:
            if not torch.is_tensor(w):
                return

        # Short-circuits if any tensor in self._flat_weights is not acceptable to cuDNN
        # or the tensors in _flat_weights are of different dtypes
        first_fw = self._flat_weights[0]
        dtype = first_fw.dtype
        for fw in self._flat_weights:
            if (not torch.is_tensor(fw.data) or not (fw.data.dtype == dtype) or
                    not fw.data.is_cuda or
                    not torch.backends.cudnn.is_acceptable(fw.data)):
                return

        # If any parameters alias, we fall back to the slower, copying code path. This is
        # a sufficient check, because overlapping parameter buffers that don't completely
        # alias would break the assumptions of the uniqueness check in
        # Module.named_parameters().
        unique_data_ptrs = {p.data_ptr() for p in self._flat_weights}
        if len(unique_data_ptrs) != len(self._flat_weights):
            return

        with torch.cuda.device_of(first_fw):
            import torch.backends.cudnn.rnn as rnn

            # Note: no_grad() is necessary since _cudnn_rnn_flatten_weight is
            # an inplace operation on self._flat_weights
            with torch.no_grad():
                if torch._use_cudnn_rnn_flatten_weight():
                    torch._cudnn_rnn_flatten_weight(
                        self._flat_weights, (4 if self.bias else 2),
                        self.input_size, rnn.get_cudnn_mode(self.mode), self.hidden_size, self.num_layers,
                        self.batch_first, bool(self.bidirectional))

    def _apply(self, fn):
        ret = super(RNNBase, self)._apply(fn)

        # Resets _flat_weights
        # Note: be v. careful before removing this, as 3rd party device types
        # likely rely on this behavior to properly .to() modules like LSTM.
        self._update_flat_weights()
        # Flattens params (on CUDA)
        self.flatten_parameters()

        return ret

    def reset_parameters(self):
        """Initialize every parameter from U(-1/sqrt(hidden_size), 1/sqrt(hidden_size))."""
        stdv = 1.0 / math.sqrt(self.hidden_size)
        for weight in self.parameters():
            init.uniform_(weight, -stdv, stdv)

    def check_input(self, input, batch_sizes):
        # type: (Tensor, Optional[Tensor]) -> None
        """Check the rank and feature size of ``input``.

        Packed data (``batch_sizes`` given) must be 2-D, otherwise 3-D, and the
        last dimension must equal ``input_size``.

        Raises:
            RuntimeError: if either check fails.
        """
        expected_input_dim = 2 if batch_sizes is not None else 3
        if input.dim() != expected_input_dim:
            raise RuntimeError(
                'input must have {} dimensions, got {}'.format(
                    expected_input_dim, input.dim()))
        if self.input_size != input.size(-1):
            raise RuntimeError(
                'input.size(-1) must be equal to input_size. Expected {}, got {}'.format(
                    self.input_size, input.size(-1)))

    def get_expected_hidden_size(self, input, batch_sizes):
        # type: (Tensor, Optional[Tensor]) -> Tuple[int, int, int]
        """Return ``(num_layers * num_directions, batch, hidden_size)``.

        The batch size is ``batch_sizes[0]`` for packed input, otherwise input
        dimension 0 when ``batch_first`` else dimension 1.
        """
        if batch_sizes is not None:
            mini_batch = batch_sizes[0]
            mini_batch = int(mini_batch)
        else:
            mini_batch = input.size(0) if self.batch_first else input.size(1)
        num_directions = 2 if self.bidirectional else 1
        expected_hidden_size = (self.num_layers * num_directions,
                                mini_batch, self.hidden_size)
        return expected_hidden_size

    def check_hidden_size(self, hx, expected_hidden_size, msg='Expected hidden size {}, got {}'):
        # type: (Tensor, Tuple[int, int, int], str) -> None
        """Raise RuntimeError (formatted with ``msg``) if ``hx`` has an unexpected size."""
        if hx.size() != expected_hidden_size:
            raise RuntimeError(msg.format(expected_hidden_size, tuple(hx.size())))

    def check_forward_args(self, input, hidden, batch_sizes):
        # type: (Tensor, Tensor, Optional[Tensor]) -> None
        """Validate ``input`` and a single hidden-state tensor before the kernel call."""
        self.check_input(input, batch_sizes)
        expected_hidden_size = self.get_expected_hidden_size(input, batch_sizes)

        self.check_hidden_size(hidden, expected_hidden_size)

    def permute_hidden(self, hx, permutation):
        # type: (Tensor, Optional[Tensor]) -> Tensor
        """Reorder ``hx`` along its batch axis (dim 1); a None permutation is a no-op."""
        if permutation is None:
            return hx
        return apply_permutation(hx, permutation)

    def forward(self, input, hx=None):
        """Run the plain RNN kernel selected by ``self.mode`` from ``_rnn_impls``.

        Accepts a padded tensor or a PackedSequence. A missing ``hx`` defaults
        to zeros. Returns ``(output, h_n)``, where ``output`` is packed again
        if the input was packed.
        """
        is_packed = isinstance(input, PackedSequence)
        if is_packed:
            input, batch_sizes, sorted_indices, unsorted_indices = input
            max_batch_size = batch_sizes[0]
            max_batch_size = int(max_batch_size)
        else:
            batch_sizes = None
            max_batch_size = input.size(0) if self.batch_first else input.size(1)
            sorted_indices = None
            unsorted_indices = None

        if hx is None:
            num_directions = 2 if self.bidirectional else 1
            hx = torch.zeros(self.num_layers * num_directions,
                             max_batch_size, self.hidden_size,
                             dtype=input.dtype, device=input.device)
        else:
            # Each batch of the hidden state should match the input sequence that
            # the user believes he/she is passing in.
            hx = self.permute_hidden(hx, sorted_indices)

        self.check_forward_args(input, hx, batch_sizes)
        _impl = _rnn_impls[self.mode]
        # The padded and packed kernel overloads take different argument lists.
        if batch_sizes is None:
            result = _impl(input, hx, self._flat_weights, self.bias, self.num_layers,
                           self.dropout, self.training, self.bidirectional, self.batch_first)
        else:
            result = _impl(input, batch_sizes, hx, self._flat_weights, self.bias,
                           self.num_layers, self.dropout, self.training, self.bidirectional)
        output = result[0]
        hidden = result[1]

        if is_packed:
            output = PackedSequence(output, batch_sizes, sorted_indices, unsorted_indices)
        # Restore the caller's batch order for the returned hidden state.
        return output, self.permute_hidden(hidden, unsorted_indices)

    def extra_repr(self):
        s = '{input_size}, {hidden_size}'
        if self.num_layers != 1:
            s += ', num_layers={num_layers}'
        if self.bias is not True:
            s += ', bias={bias}'
        if self.batch_first is not False:
            s += ', batch_first={batch_first}'
        if self.dropout != 0:
            s += ', dropout={dropout}'
        if self.bidirectional is not False:
            s += ', bidirectional={bidirectional}'
        return s.format(**self.__dict__)

    def __setstate__(self, d):
        super(RNNBase, self).__setstate__(d)
        # Some pickled states carry the weight-name list under the key
        # 'all_weights'; adopt it as _all_weights.
        if 'all_weights' in d:
            self._all_weights = d['all_weights']

        # Entries that are already attribute names need no conversion.
        if isinstance(self._all_weights[0][0], str):
            return
        # Otherwise (presumably an older format storing tensors - TODO confirm)
        # regenerate the name lists from the layer/direction layout.
        num_layers = self.num_layers
        num_directions = 2 if self.bidirectional else 1
        self._flat_weights_names = []
        self._all_weights = []
        for layer in range(num_layers):
            for direction in range(num_directions):
                suffix = '_reverse' if direction == 1 else ''
                weights = ['weight_ih_l{}{}', 'weight_hh_l{}{}', 'bias_ih_l{}{}', 'bias_hh_l{}{}']
                weights = [x.format(layer, suffix) for x in weights]
                if self.bias:
                    self._all_weights += [weights]
                    self._flat_weights_names.extend(weights)
                else:
                    self._all_weights += [weights[:2]]
                    self._flat_weights_names.extend(weights[:2])
        self._update_flat_weights()

    @property
    def all_weights(self):
        """Parameters grouped per layer/direction, in registration order."""
        return [[getattr(self, weight) for weight in weights] for weights in self._all_weights]

    def _replicate_for_data_parallel(self):
        replica = super(RNNBase, self)._replicate_for_data_parallel()
        # Need to copy these caches, otherwise the replica will share the same
        # flat weights list.
        replica._flat_weights = replica._flat_weights[:]
        replica._flat_weights_names = replica._flat_weights_names[:]
        return replica
|
|
|
|
|
|
class RNN(RNNBase):
    r"""Applies a multi-layer Elman RNN with :math:`\tanh` or :math:`\text{ReLU}` non-linearity to an
    input sequence.


    For each element in the input sequence, each layer computes the following
    function:

    .. math::
        h_t = \tanh(W_{ih} x_t + b_{ih} + W_{hh} h_{(t-1)} + b_{hh})

    where :math:`h_t` is the hidden state at time `t`, :math:`x_t` is
    the input at time `t`, and :math:`h_{(t-1)}` is the hidden state of the
    layer at time `t-1` or the initial hidden state at time `0`.
    If :attr:`nonlinearity` is ``'relu'``, then :math:`\text{ReLU}` is used instead of :math:`\tanh`.

    Args:
        input_size: The number of expected features in the input `x`
        hidden_size: The number of features in the hidden state `h`
        num_layers: Number of recurrent layers. E.g., setting ``num_layers=2``
            would mean stacking two RNNs together to form a `stacked RNN`,
            with the second RNN taking in outputs of the first RNN and
            computing the final results. Default: 1
        nonlinearity: The non-linearity to use. Can be either ``'tanh'`` or ``'relu'``.
            May be passed by keyword or positionally right after ``num_layers``.
            Default: ``'tanh'``
        bias: If ``False``, then the layer does not use bias weights `b_ih` and `b_hh`.
            Default: ``True``
        batch_first: If ``True``, then the input and output tensors are provided
            as `(batch, seq, feature)`. Default: ``False``
        dropout: If non-zero, introduces a `Dropout` layer on the outputs of each
            RNN layer except the last layer, with dropout probability equal to
            :attr:`dropout`. Default: 0
        bidirectional: If ``True``, becomes a bidirectional RNN. Default: ``False``

    Inputs: input, h_0
        - **input** of shape `(seq_len, batch, input_size)`: tensor containing the features
          of the input sequence. The input can also be a packed variable length
          sequence. See :func:`torch.nn.utils.rnn.pack_padded_sequence`
          or :func:`torch.nn.utils.rnn.pack_sequence`
          for details.
        - **h_0** of shape `(num_layers * num_directions, batch, hidden_size)`: tensor
          containing the initial hidden state for each element in the batch.
          Defaults to zero if not provided. If the RNN is bidirectional,
          num_directions should be 2, else it should be 1.

    Outputs: output, h_n
        - **output** of shape `(seq_len, batch, num_directions * hidden_size)`: tensor
          containing the output features (`h_t`) from the last layer of the RNN,
          for each `t`.  If a :class:`torch.nn.utils.rnn.PackedSequence` has
          been given as the input, the output will also be a packed sequence.

          For the unpacked case, the directions can be separated
          using ``output.view(seq_len, batch, num_directions, hidden_size)``,
          with forward and backward being direction `0` and `1` respectively.
          Similarly, the directions can be separated in the packed case.
        - **h_n** of shape `(num_layers * num_directions, batch, hidden_size)`: tensor
          containing the hidden state for `t = seq_len`.

          Like *output*, the layers can be separated using
          ``h_n.view(num_layers, num_directions, batch, hidden_size)``.

    Shape:
        - Input1: :math:`(L, N, H_{in})` tensor containing input features where
          :math:`H_{in}=\text{input\_size}` and `L` represents a sequence length.
        - Input2: :math:`(S, N, H_{out})` tensor
          containing the initial hidden state for each element in the batch,
          where :math:`H_{out}=\text{hidden\_size}` and
          :math:`S=\text{num\_layers} * \text{num\_directions}`.
          Defaults to zero if not provided.
          If the RNN is bidirectional, num_directions should be 2, else it should be 1.
        - Output1: :math:`(L, N, H_{all})` where :math:`H_{all}=\text{num\_directions} * \text{hidden\_size}`
        - Output2: :math:`(S, N, H_{out})` tensor containing the next hidden state
          for each element in the batch

    Attributes:
        weight_ih_l[k]: the learnable input-hidden weights of the k-th layer,
            of shape `(hidden_size, input_size)` for `k = 0`. Otherwise, the shape is
            `(hidden_size, num_directions * hidden_size)`
        weight_hh_l[k]: the learnable hidden-hidden weights of the k-th layer,
            of shape `(hidden_size, hidden_size)`
        bias_ih_l[k]: the learnable input-hidden bias of the k-th layer,
            of shape `(hidden_size)`
        bias_hh_l[k]: the learnable hidden-hidden bias of the k-th layer,
            of shape `(hidden_size)`

    .. note::
        All the weights and biases are initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})`
        where :math:`k = \frac{1}{\text{hidden\_size}}`

    .. include:: ../cudnn_persistent_rnn.rst

    Examples::

        >>> rnn = nn.RNN(10, 20, 2)
        >>> input = torch.randn(5, 3, 10)
        >>> h0 = torch.randn(2, 3, 20)
        >>> output, hn = rnn(input, h0)
    """

    def __init__(self, *args, **kwargs):
        # The documented argument order puts ``nonlinearity`` fourth, right
        # after ``num_layers``, but RNNBase expects ``bias`` in that slot.
        # Forwarding *args unchanged used to turn nn.RNN(10, 20, 2, 'relu')
        # into bias='relu' with a tanh non-linearity. A positional *string*
        # in that slot is therefore taken as the non-linearity. Any other
        # value keeps its previous meaning (bias) so existing callers are
        # unaffected.
        if len(args) > 3 and isinstance(args[3], str):
            if 'nonlinearity' in kwargs:
                raise TypeError("__init__() got multiple values for argument 'nonlinearity'")
            nonlinearity = args[3]
            args = args[:3] + args[4:]
        else:
            nonlinearity = kwargs.pop('nonlinearity', 'tanh')

        if nonlinearity == 'tanh':
            mode = 'RNN_TANH'
        elif nonlinearity == 'relu':
            mode = 'RNN_RELU'
        else:
            raise ValueError("Unknown nonlinearity '{}'".format(nonlinearity))
        super(RNN, self).__init__(mode, *args, **kwargs)
        self.nonlinearity = nonlinearity
|
|
|
|
|
|
# XXX: LSTM and GRU implementation is different from RNNBase, this is because:
|
|
# 1. we want to support nn.LSTM and nn.GRU in TorchScript and TorchScript in
|
|
# its current state could not support the python Union Type or Any Type
|
|
# 2. TorchScript static typing does not allow a Function or Callable type in
|
|
# Dict values, so we have to separately call _VF instead of using _rnn_impls
|
|
# 3. This is temporary only and in the transition state that we want to make it
|
|
# on time for the release
|
|
#
|
|
# More discussion details in https://github.com/pytorch/pytorch/pull/23266
|
|
#
|
|
# TODO: remove the overriding implementations for LSTM and GRU when TorchScript
|
|
# support expressing these two modules generally.
|
|
class LSTM(RNNBase):
    r"""Applies a multi-layer long short-term memory (LSTM) RNN to an input
    sequence.


    For each element in the input sequence, each layer computes the following
    function:

    .. math::
        \begin{array}{ll} \\
            i_t = \sigma(W_{ii} x_t + b_{ii} + W_{hi} h_{t-1} + b_{hi}) \\
            f_t = \sigma(W_{if} x_t + b_{if} + W_{hf} h_{t-1} + b_{hf}) \\
            g_t = \tanh(W_{ig} x_t + b_{ig} + W_{hg} h_{t-1} + b_{hg}) \\
            o_t = \sigma(W_{io} x_t + b_{io} + W_{ho} h_{t-1} + b_{ho}) \\
            c_t = f_t \odot c_{t-1} + i_t \odot g_t \\
            h_t = o_t \odot \tanh(c_t) \\
        \end{array}

    where :math:`h_t` is the hidden state at time `t`, :math:`c_t` is the cell
    state at time `t`, :math:`x_t` is the input at time `t`, :math:`h_{t-1}`
    is the hidden state of the layer at time `t-1` or the initial hidden
    state at time `0`, and :math:`i_t`, :math:`f_t`, :math:`g_t`,
    :math:`o_t` are the input, forget, cell, and output gates, respectively.
    :math:`\sigma` is the sigmoid function, and :math:`\odot` is the Hadamard product.

    In a multilayer LSTM, the input :math:`x^{(l)}_t` of the :math:`l` -th layer
    (:math:`l >= 2`) is the hidden state :math:`h^{(l-1)}_t` of the previous layer multiplied by
    dropout :math:`\delta^{(l-1)}_t` where each :math:`\delta^{(l-1)}_t` is a Bernoulli random
    variable which is :math:`0` with probability :attr:`dropout`.

    Args:
        input_size: The number of expected features in the input `x`
        hidden_size: The number of features in the hidden state `h`
        num_layers: Number of recurrent layers. E.g., setting ``num_layers=2``
            would mean stacking two LSTMs together to form a `stacked LSTM`,
            with the second LSTM taking in outputs of the first LSTM and
            computing the final results. Default: 1
        bias: If ``False``, then the layer does not use bias weights `b_ih` and `b_hh`.
            Default: ``True``
        batch_first: If ``True``, then the input and output tensors are provided
            as (batch, seq, feature). Default: ``False``
        dropout: If non-zero, introduces a `Dropout` layer on the outputs of each
            LSTM layer except the last layer, with dropout probability equal to
            :attr:`dropout`. Default: 0
        bidirectional: If ``True``, becomes a bidirectional LSTM. Default: ``False``

    Inputs: input, (h_0, c_0)
        - **input** of shape `(seq_len, batch, input_size)`: tensor containing the features
          of the input sequence.
          The input can also be a packed variable length sequence.
          See :func:`torch.nn.utils.rnn.pack_padded_sequence` or
          :func:`torch.nn.utils.rnn.pack_sequence` for details.
        - **h_0** of shape `(num_layers * num_directions, batch, hidden_size)`: tensor
          containing the initial hidden state for each element in the batch.
          If the LSTM is bidirectional, num_directions should be 2, else it should be 1.
        - **c_0** of shape `(num_layers * num_directions, batch, hidden_size)`: tensor
          containing the initial cell state for each element in the batch.

          If `(h_0, c_0)` is not provided, both **h_0** and **c_0** default to zero.


    Outputs: output, (h_n, c_n)
        - **output** of shape `(seq_len, batch, num_directions * hidden_size)`: tensor
          containing the output features `(h_t)` from the last layer of the LSTM,
          for each `t`. If a :class:`torch.nn.utils.rnn.PackedSequence` has been
          given as the input, the output will also be a packed sequence.

          For the unpacked case, the directions can be separated
          using ``output.view(seq_len, batch, num_directions, hidden_size)``,
          with forward and backward being direction `0` and `1` respectively.
          Similarly, the directions can be separated in the packed case.
        - **h_n** of shape `(num_layers * num_directions, batch, hidden_size)`: tensor
          containing the hidden state for `t = seq_len`.

          Like *output*, the layers can be separated using
          ``h_n.view(num_layers, num_directions, batch, hidden_size)`` and similarly for *c_n*.
        - **c_n** of shape `(num_layers * num_directions, batch, hidden_size)`: tensor
          containing the cell state for `t = seq_len`.

    Attributes:
        weight_ih_l[k] : the learnable input-hidden weights of the :math:`\text{k}^{th}` layer
            `(W_ii|W_if|W_ig|W_io)`, of shape `(4*hidden_size, input_size)` for `k = 0`.
            Otherwise, the shape is `(4*hidden_size, num_directions * hidden_size)`
        weight_hh_l[k] : the learnable hidden-hidden weights of the :math:`\text{k}^{th}` layer
            `(W_hi|W_hf|W_hg|W_ho)`, of shape `(4*hidden_size, hidden_size)`
        bias_ih_l[k] : the learnable input-hidden bias of the :math:`\text{k}^{th}` layer
            `(b_ii|b_if|b_ig|b_io)`, of shape `(4*hidden_size)`
        bias_hh_l[k] : the learnable hidden-hidden bias of the :math:`\text{k}^{th}` layer
            `(b_hi|b_hf|b_hg|b_ho)`, of shape `(4*hidden_size)`

    .. note::
        All the weights and biases are initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})`
        where :math:`k = \frac{1}{\text{hidden\_size}}`

    .. include:: ../cudnn_persistent_rnn.rst

    Examples::

        >>> rnn = nn.LSTM(10, 20, 2)
        >>> input = torch.randn(5, 3, 10)
        >>> h0 = torch.randn(2, 3, 20)
        >>> c0 = torch.randn(2, 3, 20)
        >>> output, (hn, cn) = rnn(input, (h0, c0))
    """
    def __init__(self, *args, **kwargs):
        # All arguments are forwarded unchanged to RNNBase with mode 'LSTM'
        # (gate size 4 * hidden_size).
        super(LSTM, self).__init__('LSTM', *args, **kwargs)

    def check_forward_args(self, input, hidden, batch_sizes):
        # type: (Tensor, Tuple[Tensor, Tensor], Optional[Tensor]) -> None
        """Validate ``input`` and both LSTM states.

        Overrides the single-tensor RNNBase version because the LSTM state is
        a ``(h, c)`` tuple; each element must have the expected hidden size.

        Raises:
            RuntimeError: if ``input`` has the wrong rank or feature size, or
                if either state tensor has an unexpected size.
        """
        self.check_input(input, batch_sizes)
        expected_hidden_size = self.get_expected_hidden_size(input, batch_sizes)

        self.check_hidden_size(hidden[0], expected_hidden_size,
                               'Expected hidden[0] size {}, got {}')
        self.check_hidden_size(hidden[1], expected_hidden_size,
                               'Expected hidden[1] size {}, got {}')

    def permute_hidden(self, hx, permutation):
        # type: (Tuple[Tensor, Tensor], Optional[Tensor]) -> Tuple[Tensor, Tensor]
        """Apply ``permutation`` to both the hidden and the cell state.

        Returns ``hx`` unchanged when ``permutation`` is None (unpacked input).
        """
        if permutation is None:
            return hx
        return apply_permutation(hx[0], permutation), apply_permutation(hx[1], permutation)

    # TorchScript overload declaration: padded Tensor input.
    @overload
    @torch._jit_internal._overload_method  # noqa: F811
    def forward(self, input, hx=None):  # noqa: F811
        # type: (Tensor, Optional[Tuple[Tensor, Tensor]]) -> Tuple[Tensor, Tuple[Tensor, Tensor]]
        pass

    # TorchScript overload declaration: PackedSequence input.
    @overload
    @torch._jit_internal._overload_method  # noqa: F811
    def forward(self, input, hx=None):  # noqa: F811
        # type: (PackedSequence, Optional[Tuple[Tensor, Tensor]]) -> Tuple[PackedSequence, Tuple[Tensor, Tensor]]  # noqa
        pass

    def forward(self, input, hx=None):  # noqa: F811
        """Run the LSTM over a padded tensor or a PackedSequence.

        Returns ``(output, (h_n, c_n))``. ``output`` is a PackedSequence when
        the input was one. A missing ``hx`` defaults to zero ``h_0`` and
        ``c_0``.
        """
        orig_input = input
        # xxx: isinstance check needs to be in conditional for TorchScript to compile
        if isinstance(orig_input, PackedSequence):
            input, batch_sizes, sorted_indices, unsorted_indices = input
            max_batch_size = batch_sizes[0]
            max_batch_size = int(max_batch_size)
        else:
            batch_sizes = None
            max_batch_size = input.size(0) if self.batch_first else input.size(1)
            sorted_indices = None
            unsorted_indices = None

        if hx is None:
            num_directions = 2 if self.bidirectional else 1
            zeros = torch.zeros(self.num_layers * num_directions,
                                max_batch_size, self.hidden_size,
                                dtype=input.dtype, device=input.device)
            # The same zero tensor is used for both h_0 and c_0.
            hx = (zeros, zeros)
        else:
            # Each batch of the hidden state should match the input sequence that
            # the user believes he/she is passing in.
            hx = self.permute_hidden(hx, sorted_indices)

        self.check_forward_args(input, hx, batch_sizes)
        # Padded and packed inputs use different _VF.lstm argument lists.
        if batch_sizes is None:
            result = _VF.lstm(input, hx, self._flat_weights, self.bias, self.num_layers,
                              self.dropout, self.training, self.bidirectional, self.batch_first)
        else:
            result = _VF.lstm(input, batch_sizes, hx, self._flat_weights, self.bias,
                              self.num_layers, self.dropout, self.training, self.bidirectional)
        output = result[0]
        # Remaining elements are the final (h_n, c_n) pair.
        hidden = result[1:]
        # xxx: isinstance check needs to be in conditional for TorchScript to compile
        if isinstance(orig_input, PackedSequence):
            output_packed = PackedSequence(output, batch_sizes, sorted_indices, unsorted_indices)
            return output_packed, self.permute_hidden(hidden, unsorted_indices)
        else:
            return output, self.permute_hidden(hidden, unsorted_indices)
|
|
|
|
|
|
class GRU(RNNBase):
|
|
r"""Applies a multi-layer gated recurrent unit (GRU) RNN to an input sequence.
|
|
|
|
|
|
For each element in the input sequence, each layer computes the following
|
|
function:
|
|
|
|
.. math::
|
|
\begin{array}{ll}
|
|
r_t = \sigma(W_{ir} x_t + b_{ir} + W_{hr} h_{(t-1)} + b_{hr}) \\
|
|
z_t = \sigma(W_{iz} x_t + b_{iz} + W_{hz} h_{(t-1)} + b_{hz}) \\
|
|
n_t = \tanh(W_{in} x_t + b_{in} + r_t * (W_{hn} h_{(t-1)}+ b_{hn})) \\
|
|
h_t = (1 - z_t) * n_t + z_t * h_{(t-1)}
|
|
\end{array}
|
|
|
|
where :math:`h_t` is the hidden state at time `t`, :math:`x_t` is the input
|
|
at time `t`, :math:`h_{(t-1)}` is the hidden state of the layer
|
|
at time `t-1` or the initial hidden state at time `0`, and :math:`r_t`,
|
|
:math:`z_t`, :math:`n_t` are the reset, update, and new gates, respectively.
|
|
:math:`\sigma` is the sigmoid function, and :math:`*` is the Hadamard product.
|
|
|
|
In a multilayer GRU, the input :math:`x^{(l)}_t` of the :math:`l` -th layer
|
|
(:math:`l >= 2`) is the hidden state :math:`h^{(l-1)}_t` of the previous layer multiplied by
|
|
dropout :math:`\delta^{(l-1)}_t` where each :math:`\delta^{(l-1)}_t` is a Bernoulli random
|
|
variable which is :math:`0` with probability :attr:`dropout`.
|
|
|
|
Args:
|
|
input_size: The number of expected features in the input `x`
|
|
hidden_size: The number of features in the hidden state `h`
|
|
num_layers: Number of recurrent layers. E.g., setting ``num_layers=2``
|
|
would mean stacking two GRUs together to form a `stacked GRU`,
|
|
with the second GRU taking in outputs of the first GRU and
|
|
computing the final results. Default: 1
|
|
bias: If ``False``, then the layer does not use bias weights `b_ih` and `b_hh`.
|
|
Default: ``True``
|
|
batch_first: If ``True``, then the input and output tensors are provided
|
|
as (batch, seq, feature). Default: ``False``
|
|
dropout: If non-zero, introduces a `Dropout` layer on the outputs of each
|
|
GRU layer except the last layer, with dropout probability equal to
|
|
:attr:`dropout`. Default: 0
|
|
bidirectional: If ``True``, becomes a bidirectional GRU. Default: ``False``
|
|
|
|
Inputs: input, h_0
|
|
- **input** of shape `(seq_len, batch, input_size)`: tensor containing the features
|
|
of the input sequence. The input can also be a packed variable length
|
|
sequence. See :func:`torch.nn.utils.rnn.pack_padded_sequence`
|
|
for details.
|
|
- **h_0** of shape `(num_layers * num_directions, batch, hidden_size)`: tensor
|
|
containing the initial hidden state for each element in the batch.
|
|
Defaults to zero if not provided. If the RNN is bidirectional,
|
|
num_directions should be 2, else it should be 1.
|
|
|
|
Outputs: output, h_n
|
|
- **output** of shape `(seq_len, batch, num_directions * hidden_size)`: tensor
|
|
containing the output features h_t from the last layer of the GRU,
|
|
for each `t`. If a :class:`torch.nn.utils.rnn.PackedSequence` has been
|
|
given as the input, the output will also be a packed sequence.
|
|
For the unpacked case, the directions can be separated
|
|
using ``output.view(seq_len, batch, num_directions, hidden_size)``,
|
|
with forward and backward being direction `0` and `1` respectively.
|
|
|
|
Similarly, the directions can be separated in the packed case.
|
|
- **h_n** of shape `(num_layers * num_directions, batch, hidden_size)`: tensor
|
|
containing the hidden state for `t = seq_len`
|
|
|
|
Like *output*, the layers can be separated using
|
|
``h_n.view(num_layers, num_directions, batch, hidden_size)``.
|
|
|
|
Shape:
|
|
- Input1: :math:`(L, N, H_{in})` tensor containing input features where
|
|
:math:`H_{in}=\text{input\_size}` and `L` represents a sequence length.
|
|
- Input2: :math:`(S, N, H_{out})` tensor
|
|
containing the initial hidden state for each element in the batch.
|
|
:math:`H_{out}=\text{hidden\_size}`
|
|
Defaults to zero if not provided. where :math:`S=\text{num\_layers} * \text{num\_directions}`
|
|
If the RNN is bidirectional, num_directions should be 2, else it should be 1.
|
|
- Output1: :math:`(L, N, H_{all})` where :math:`H_{all}=\text{num\_directions} * \text{hidden\_size}`
|
|
- Output2: :math:`(S, N, H_{out})` tensor containing the next hidden state
|
|
for each element in the batch
|
|
|
|
Attributes:
|
|
weight_ih_l[k] : the learnable input-hidden weights of the :math:`\text{k}^{th}` layer
|
|
(W_ir|W_iz|W_in), of shape `(3*hidden_size, input_size)` for `k = 0`.
|
|
Otherwise, the shape is `(3*hidden_size, num_directions * hidden_size)`
|
|
weight_hh_l[k] : the learnable hidden-hidden weights of the :math:`\text{k}^{th}` layer
|
|
(W_hr|W_hz|W_hn), of shape `(3*hidden_size, hidden_size)`
|
|
bias_ih_l[k] : the learnable input-hidden bias of the :math:`\text{k}^{th}` layer
|
|
(b_ir|b_iz|b_in), of shape `(3*hidden_size)`
|
|
bias_hh_l[k] : the learnable hidden-hidden bias of the :math:`\text{k}^{th}` layer
|
|
(b_hr|b_hz|b_hn), of shape `(3*hidden_size)`
|
|
|
|
.. note::
|
|
All the weights and biases are initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})`
|
|
where :math:`k = \frac{1}{\text{hidden\_size}}`
|
|
|
|
.. include:: ../cudnn_persistent_rnn.rst
|
|
|
|
Examples::
|
|
|
|
>>> rnn = nn.GRU(10, 20, 2)
|
|
>>> input = torch.randn(5, 3, 10)
|
|
>>> h0 = torch.randn(2, 3, 20)
|
|
>>> output, hn = rnn(input, h0)
|
|
"""
|
|
def __init__(self, *args, **kwargs):
|
|
super(GRU, self).__init__('GRU', *args, **kwargs)
|
|
|
|
@overload
|
|
@torch._jit_internal._overload_method # noqa: F811
|
|
def forward(self, input, hx=None): # noqa: F811
|
|
# type: (Tensor, Optional[Tensor]) -> Tuple[Tensor, Tensor]
|
|
pass
|
|
|
|
@overload
|
|
@torch._jit_internal._overload_method # noqa: F811
|
|
def forward(self, input, hx=None): # noqa: F811
|
|
# type: (PackedSequence, Optional[Tensor]) -> Tuple[PackedSequence, Tensor]
|
|
pass
|
|
|
|
def forward(self, input, hx=None):  # noqa: F811
    # Run the GRU over `input` and return (output, hidden).
    #
    # Args:
    #   input: either a padded Tensor (batch dim is 0 when self.batch_first,
    #       otherwise 1) or a PackedSequence.
    #   hx: optional initial hidden state; zeros are used when it is None.
    # Returns:
    #   (output, hidden). `output` is re-wrapped as a PackedSequence when the
    #   input was packed; `hidden` is always passed through permute_hidden
    #   with unsorted_indices (a no-op argument of None for padded input).
    #
    # Statement order and the isinstance placement below are dictated by
    # TorchScript compilation; keep them as-is.
    orig_input = input
    # xxx: isinstance check needs to be in conditional for TorchScript to compile
    if isinstance(orig_input, PackedSequence):
        # Unpack into the flat data tensor plus the packing metadata.
        input, batch_sizes, sorted_indices, unsorted_indices = input
        # The first entry of batch_sizes is used as the largest batch
        # (presumably packed batches shrink over time steps -- confirm).
        max_batch_size = batch_sizes[0]
        # Convert to a Python int so it can be used as a size in torch.zeros.
        max_batch_size = int(max_batch_size)
    else:
        batch_sizes = None
        max_batch_size = input.size(0) if self.batch_first else input.size(1)
        sorted_indices = None
        unsorted_indices = None

    if hx is None:
        num_directions = 2 if self.bidirectional else 1
        # Zero initial state of shape
        # (num_layers * num_directions, max_batch_size, hidden_size)
        # with the same dtype and device as the input.
        hx = torch.zeros(self.num_layers * num_directions,
                         max_batch_size, self.hidden_size,
                         dtype=input.dtype, device=input.device)
    else:
        # Each batch of the hidden state should match the input sequence that
        # the user believes he/she is passing in.
        hx = self.permute_hidden(hx, sorted_indices)

    # check_forward_args is defined outside this view; presumably it
    # validates input/hx sizes against the module configuration.
    self.check_forward_args(input, hx, batch_sizes)
    if batch_sizes is None:
        # Padded-input form of _VF.gru (takes batch_first as the last arg).
        result = _VF.gru(input, hx, self._flat_weights, self.bias, self.num_layers,
                         self.dropout, self.training, self.bidirectional, self.batch_first)
    else:
        # Packed-input form of _VF.gru (takes batch_sizes, no batch_first).
        result = _VF.gru(input, batch_sizes, hx, self._flat_weights, self.bias,
                         self.num_layers, self.dropout, self.training, self.bidirectional)
    output = result[0]
    hidden = result[1]

    # xxx: isinstance check needs to be in conditional for TorchScript to compile
    if isinstance(orig_input, PackedSequence):
        output_packed = PackedSequence(output, batch_sizes, sorted_indices, unsorted_indices)
        # Presumably undoes the reordering applied to hx via sorted_indices.
        return output_packed, self.permute_hidden(hidden, unsorted_indices)
    else:
        return output, self.permute_hidden(hidden, unsorted_indices)
|
|
|
|
|
|
class RNNCellBase(Module):
    """Shared base of the single-step cells (RNNCell, LSTMCell, GRUCell).

    Owns the stacked gate weights/biases and provides the input/hidden size
    checks that the cells' ``forward`` methods call.
    """
    __constants__ = ['input_size', 'hidden_size', 'bias']

    def __init__(self, input_size, hidden_size, bias, num_chunks):
        """Allocate weights for ``num_chunks`` stacked gates and initialize them.

        Args:
            input_size: number of features in the input.
            hidden_size: number of features in the hidden state.
            bias: when falsy, ``bias_ih``/``bias_hh`` are registered as ``None``.
            num_chunks: number of gate blocks stacked along dim 0 of every
                weight and bias (1 for RNNCell, 4 for LSTMCell, 3 for GRUCell).
        """
        super(RNNCellBase, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.bias = bias
        gate_rows = num_chunks * hidden_size
        self.weight_ih = Parameter(torch.Tensor(gate_rows, input_size))
        self.weight_hh = Parameter(torch.Tensor(gate_rows, hidden_size))
        if not bias:
            self.register_parameter('bias_ih', None)
            self.register_parameter('bias_hh', None)
        else:
            self.bias_ih = Parameter(torch.Tensor(gate_rows))
            self.bias_hh = Parameter(torch.Tensor(gate_rows))
        # The tensors above are allocated without meaningful values;
        # reset_parameters fills every registered parameter.
        self.reset_parameters()

    def extra_repr(self):
        """Describe sizes, plus ``bias``/``nonlinearity`` when non-default."""
        attrs = self.__dict__
        parts = ['{input_size}, {hidden_size}']
        if 'bias' in attrs and self.bias is not True:
            parts.append(', bias={bias}')
        if 'nonlinearity' in attrs and self.nonlinearity != "tanh":
            parts.append(', nonlinearity={nonlinearity}')
        return ''.join(parts).format(**attrs)

    def check_forward_input(self, input):
        # Raise RuntimeError when the feature dimension (dim 1) of `input`
        # differs from the configured input_size.
        got = input.size(1)
        if got != self.input_size:
            raise RuntimeError(
                "input has inconsistent input_size: got {}, expected {}".format(
                    got, self.input_size))

    def check_forward_hidden(self, input, hx, hidden_label=''):
        # type: (Tensor, Tensor, str) -> None
        # Raise RuntimeError when hx's batch (dim 0) differs from the input's,
        # or when hx's width (dim 1) differs from hidden_size. The label is
        # spliced into the message to tell LSTM's h and c apart.
        input_batch = input.size(0)
        hidden_batch = hx.size(0)
        if input_batch != hidden_batch:
            raise RuntimeError(
                "Input batch size {} doesn't match hidden{} batch size {}".format(
                    input_batch, hidden_label, hidden_batch))

        hidden_width = hx.size(1)
        if hidden_width != self.hidden_size:
            raise RuntimeError(
                "hidden{} has inconsistent hidden_size: got {}, expected {}".format(
                    hidden_label, hidden_width, self.hidden_size))

    def reset_parameters(self):
        """Draw every parameter from U(-1/sqrt(hidden_size), 1/sqrt(hidden_size))."""
        bound = 1.0 / math.sqrt(self.hidden_size)
        for param in self.parameters():
            init.uniform_(param, -bound, bound)
|
|
|
|
|
|
class RNNCell(RNNCellBase):
    r"""An Elman RNN cell with tanh or ReLU non-linearity.

    .. math::

        h' = \tanh(W_{ih} x + b_{ih} + W_{hh} h + b_{hh})

    If :attr:`nonlinearity` is `'relu'`, then ReLU is used in place of tanh.

    Args:
        input_size: The number of expected features in the input `x`
        hidden_size: The number of features in the hidden state `h`
        bias: If ``False``, then the layer does not use bias weights `b_ih` and `b_hh`.
            Default: ``True``
        nonlinearity: The non-linearity to use. Can be either ``'tanh'`` or ``'relu'``. Default: ``'tanh'``

    Inputs: input, hidden
        - **input** of shape `(batch, input_size)`: tensor containing input features
        - **hidden** of shape `(batch, hidden_size)`: tensor containing the initial hidden
          state for each element in the batch.
          Defaults to zero if not provided.

    Outputs: h'
        - **h'** of shape `(batch, hidden_size)`: tensor containing the next hidden state
          for each element in the batch

    Shape:
        - Input1: :math:`(N, H_{in})` tensor containing input features where
          :math:`H_{in}` = `input_size`
        - Input2: :math:`(N, H_{out})` tensor containing the initial hidden
          state for each element in the batch where :math:`H_{out}` = `hidden_size`
          Defaults to zero if not provided.
        - Output: :math:`(N, H_{out})` tensor containing the next hidden state
          for each element in the batch

    Attributes:
        weight_ih: the learnable input-hidden weights, of shape
            `(hidden_size, input_size)`
        weight_hh: the learnable hidden-hidden weights, of shape
            `(hidden_size, hidden_size)`
        bias_ih: the learnable input-hidden bias, of shape `(hidden_size)`
        bias_hh: the learnable hidden-hidden bias, of shape `(hidden_size)`

    .. note::
        All the weights and biases are initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})`
        where :math:`k = \frac{1}{\text{hidden\_size}}`

    Raises:
        RuntimeError: from ``forward`` when :attr:`nonlinearity` is neither
            ``'tanh'`` nor ``'relu'``, or when the input/hidden sizes do not
            match the cell's configuration.

    Examples::

        >>> rnn = nn.RNNCell(10, 20)
        >>> input = torch.randn(6, 3, 10)
        >>> hx = torch.randn(3, 20)
        >>> output = []
        >>> for i in range(6):
        ...     hx = rnn(input[i], hx)
        ...     output.append(hx)
    """
    __constants__ = ['input_size', 'hidden_size', 'bias', 'nonlinearity']

    def __init__(self, input_size, hidden_size, bias=True, nonlinearity="tanh"):
        super(RNNCell, self).__init__(input_size, hidden_size, bias, num_chunks=1)
        # Validated lazily in forward, not here.
        self.nonlinearity = nonlinearity

    def forward(self, input, hx=None):
        # type: (Tensor, Optional[Tensor]) -> Tensor
        # Advance the cell by one step. When hx is omitted, a zero state of
        # size (batch, hidden_size) with the input's dtype/device is used.
        self.check_forward_input(input)
        if hx is None:
            hx = torch.zeros(input.size(0), self.hidden_size, dtype=input.dtype, device=input.device)
        self.check_forward_hidden(input, hx, '')
        if self.nonlinearity == "tanh":
            ret = _VF.rnn_tanh_cell(
                input, hx,
                self.weight_ih, self.weight_hh,
                self.bias_ih, self.bias_hh,
            )
        elif self.nonlinearity == "relu":
            ret = _VF.rnn_relu_cell(
                input, hx,
                self.weight_ih, self.weight_hh,
                self.bias_ih, self.bias_hh,
            )
        else:
            # The dummy assignment keeps `ret` defined on every path for the
            # TorchScript compiler.
            ret = input  # TODO: remove when jit supports exception flow
            raise RuntimeError(
                "Unknown nonlinearity: {}".format(self.nonlinearity))
        return ret
|
|
|
|
|
|
class LSTMCell(RNNCellBase):
    r"""A long short-term memory (LSTM) cell.

    .. math::

        \begin{array}{ll}
        i = \sigma(W_{ii} x + b_{ii} + W_{hi} h + b_{hi}) \\
        f = \sigma(W_{if} x + b_{if} + W_{hf} h + b_{hf}) \\
        g = \tanh(W_{ig} x + b_{ig} + W_{hg} h + b_{hg}) \\
        o = \sigma(W_{io} x + b_{io} + W_{ho} h + b_{ho}) \\
        c' = f * c + i * g \\
        h' = o * \tanh(c') \\
        \end{array}

    where :math:`\sigma` is the sigmoid function, and :math:`*` is the Hadamard product.

    Args:
        input_size: The number of expected features in the input `x`
        hidden_size: The number of features in the hidden state `h`
        bias: If ``False``, then the layer does not use bias weights `b_ih` and
            `b_hh`. Default: ``True``

    Inputs: input, (h_0, c_0)
        - **input** of shape `(batch, input_size)`: tensor containing input features
        - **h_0** of shape `(batch, hidden_size)`: tensor containing the initial hidden
          state for each element in the batch.
        - **c_0** of shape `(batch, hidden_size)`: tensor containing the initial cell state
          for each element in the batch.

          If `(h_0, c_0)` is not provided, both **h_0** and **c_0** default to zero.

    Outputs: (h_1, c_1)
        - **h_1** of shape `(batch, hidden_size)`: tensor containing the next hidden state
          for each element in the batch
        - **c_1** of shape `(batch, hidden_size)`: tensor containing the next cell state
          for each element in the batch

    Attributes:
        weight_ih: the learnable input-hidden weights, of shape
            `(4*hidden_size, input_size)`
        weight_hh: the learnable hidden-hidden weights, of shape
            `(4*hidden_size, hidden_size)`
        bias_ih: the learnable input-hidden bias, of shape `(4*hidden_size)`
        bias_hh: the learnable hidden-hidden bias, of shape `(4*hidden_size)`

    .. note::
        All the weights and biases are initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})`
        where :math:`k = \frac{1}{\text{hidden\_size}}`

    Raises:
        RuntimeError: from ``forward`` when the input, **h_0** or **c_0**
            sizes do not match the cell's configuration.

    Examples::

        >>> rnn = nn.LSTMCell(10, 20)
        >>> input = torch.randn(6, 3, 10)
        >>> hx = torch.randn(3, 20)
        >>> cx = torch.randn(3, 20)
        >>> output = []
        >>> for i in range(6):
        ...     hx, cx = rnn(input[i], (hx, cx))
        ...     output.append(hx)
    """

    def __init__(self, input_size, hidden_size, bias=True):
        # Four stacked gate blocks (i, f, g, o) per weight/bias.
        super(LSTMCell, self).__init__(input_size, hidden_size, bias, num_chunks=4)

    def forward(self, input, hx=None):
        # type: (Tensor, Optional[Tuple[Tensor, Tensor]]) -> Tuple[Tensor, Tensor]
        # Advance the cell by one step and return (h_1, c_1).
        self.check_forward_input(input)
        if hx is None:
            # The same zero tensor is reused for both h_0 and c_0; this relies
            # on _VF.lstm_cell not writing into its inputs (assumed).
            zeros = torch.zeros(input.size(0), self.hidden_size, dtype=input.dtype, device=input.device)
            hx = (zeros, zeros)
        # The labels distinguish h (hx[0]) from c (hx[1]) in error messages.
        self.check_forward_hidden(input, hx[0], '[0]')
        self.check_forward_hidden(input, hx[1], '[1]')
        return _VF.lstm_cell(
            input, hx,
            self.weight_ih, self.weight_hh,
            self.bias_ih, self.bias_hh,
        )
|
|
|
|
|
|
class GRUCell(RNNCellBase):
    r"""A gated recurrent unit (GRU) cell.

    .. math::

        \begin{array}{ll}
        r = \sigma(W_{ir} x + b_{ir} + W_{hr} h + b_{hr}) \\
        z = \sigma(W_{iz} x + b_{iz} + W_{hz} h + b_{hz}) \\
        n = \tanh(W_{in} x + b_{in} + r * (W_{hn} h + b_{hn})) \\
        h' = (1 - z) * n + z * h
        \end{array}

    where :math:`\sigma` is the sigmoid function, and :math:`*` is the Hadamard product.

    Args:
        input_size: The number of expected features in the input `x`
        hidden_size: The number of features in the hidden state `h`
        bias: If ``False``, then the layer does not use bias weights `b_ih` and
            `b_hh`. Default: ``True``

    Inputs: input, hidden
        - **input** of shape `(batch, input_size)`: tensor containing input features
        - **hidden** of shape `(batch, hidden_size)`: tensor containing the initial hidden
          state for each element in the batch.
          Defaults to zero if not provided.

    Outputs: h'
        - **h'** of shape `(batch, hidden_size)`: tensor containing the next hidden state
          for each element in the batch

    Shape:
        - Input1: :math:`(N, H_{in})` tensor containing input features where
          :math:`H_{in}` = `input_size`
        - Input2: :math:`(N, H_{out})` tensor containing the initial hidden
          state for each element in the batch where :math:`H_{out}` = `hidden_size`
          Defaults to zero if not provided.
        - Output: :math:`(N, H_{out})` tensor containing the next hidden state
          for each element in the batch

    Attributes:
        weight_ih: the learnable input-hidden weights, of shape
            `(3*hidden_size, input_size)`
        weight_hh: the learnable hidden-hidden weights, of shape
            `(3*hidden_size, hidden_size)`
        bias_ih: the learnable input-hidden bias, of shape `(3*hidden_size)`
        bias_hh: the learnable hidden-hidden bias, of shape `(3*hidden_size)`

    .. note::
        All the weights and biases are initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})`
        where :math:`k = \frac{1}{\text{hidden\_size}}`

    Raises:
        RuntimeError: from ``forward`` when the input or hidden sizes do not
            match the cell's configuration.

    Examples::

        >>> rnn = nn.GRUCell(10, 20)
        >>> input = torch.randn(6, 3, 10)
        >>> hx = torch.randn(3, 20)
        >>> output = []
        >>> for i in range(6):
        ...     hx = rnn(input[i], hx)
        ...     output.append(hx)
    """

    def __init__(self, input_size, hidden_size, bias=True):
        # Three stacked gate blocks (r, z, n) per weight/bias.
        super(GRUCell, self).__init__(input_size, hidden_size, bias, num_chunks=3)

    def forward(self, input, hx=None):
        # type: (Tensor, Optional[Tensor]) -> Tensor
        # Advance the cell by one step. When hx is omitted, a zero state of
        # size (batch, hidden_size) with the input's dtype/device is used.
        self.check_forward_input(input)
        if hx is None:
            hx = torch.zeros(input.size(0), self.hidden_size, dtype=input.dtype, device=input.device)
        self.check_forward_hidden(input, hx, '')
        return _VF.gru_cell(
            input, hx,
            self.weight_ih, self.weight_hh,
            self.bias_ih, self.bias_hh,
        )
|