Mirror of https://github.com/pytorch/pytorch.git, synced 2025-11-05 00:14:54 +08:00
Revert D5589309: modify _LSTM into _RNN to adapt GRU
Summary: This reverts commit f5af67dfe0842acd68223f6da3e96a81639e8049

bypass-lint

Differential Revision: D5589309

fbshipit-source-id: 79b0a3a9455829c3899472a1368ef36dc75f6e14
committed by Facebook Github Bot
parent b91c2f5064
commit a7be496fe2
@@ -1,4 +1,4 @@
-# @package rnn_cell
+## @package rnn_cell
 # Module caffe2.python.rnn_cell
 from __future__ import absolute_import
 from __future__ import division
@@ -35,7 +35,6 @@ class RNNCell(object):
     As a result base class will provice apply_over_sequence method, which
     allows you to apply recurrent operations over a sequence of any length.
     '''
-
     def __init__(self, name, forward_only=False):
         self.name = name
         self.recompute_blobs = []
@@ -907,7 +906,7 @@ class MILSTMWithAttentionCell(AttentionCell):
         )


-def _RNN(
+def _LSTM(
     cell_class,
     model,
     input_blob,
@@ -924,57 +923,51 @@ def _RNN(
     drop_states=False,
     return_last_layer_only=True,
     static_rnn_unroll_size=None,
-    no_cell_state=False,
 ):
     '''
-    Adds a standard LSTM/MILSTM/GRU recurrent network operator to a model.
+    Adds a standard LSTM recurrent network operator to a model.
 
-    cell_class: LSTMCell, GRUCell or compatible subclass.
+    cell_class: LSTMCell or compatible subclass
 
-    model: ModelHelper object new operators would be added to.
+    model: ModelHelper object new operators would be added to
 
-    input_blob: the input sequence in a format T x N x D,
-        where T is sequence size, N - batch size and D - input dimension.
+    input_blob: the input sequence in a format T x N x D
+        where T is sequence size, N - batch size and D - input dimension
 
     seq_lengths: blob containing sequence lengths which would be passed to
-        LSTMUnit operator.
+        LSTMUnit operator
 
     initial_states: a list of (2 * num_layers) blobs representing the initial
-        hidden and cell states of each layer For LSTM classes and a list
-        of num_layers blobs for GRU. If this argument is None,
+        hidden and cell states of each layer. If this argument is None,
         these states will be added to the model as network parameters.
 
-    dim_in: input dimension.
+    dim_in: input dimension
 
-    dim_out: number of units per RNN layer
-        (use int for single-layer RNN, list of ints for multi-layer).
+    dim_out: number of units per LSTM layer
+        (use int for single-layer LSTM, list of ints for multi-layer)
 
     outputs_with_grads : position indices of output blobs for LAST LAYER which
         will receive external error gradient during backpropagation.
-        These outputs are: (h_all, h_last, c_all, c_last) for LSTM classes
-        and (h_all, h_last) for GRU.
+        These outputs are: (h_all, h_last, c_all, c_last)
 
-    return_params: if True, will return a dictionary of parameters of the RNN.
+    return_params: if True, will return a dictionary of parameters of the LSTM
 
-    memory_optimization: if enabled, the RNN step is recomputed on backward
+    memory_optimization: if enabled, the LSTM step is recomputed on backward
         step so that we don't need to store forward activations for each
         timestep. Saves memory with cost of computation.
 
-    forget_bias: forget gate bias (default 0.0).
+    forget_bias: forget gate bias (default 0.0)
 
-    forward_only: whether to create a backward pass.
+    forward_only: whether to create a backward pass
 
-    drop_states: drop invalid states, passed to LSTMUnit/GRUUnit operator.
+    drop_states: drop invalid states, passed through to LSTMUnit operator
 
     return_last_layer_only: only return outputs from final layer
-        (so that length of results does depend on number of layers).
+        (so that length of results does depend on number of layers)
 
     static_rnn_unroll_size: if not None, we will use static RNN which is
-        unrolled into Caffe2 graph. The size of the unroll is the value of
-        this parameter.
-
-    no_cell_state: whether there is cell state in the model.
-        It is False for LSTM and True for GRU.
+        unrolled into Caffe2 graph. The size of the unroll is the value of
+        this parameter.
     '''
     if type(dim_out) is not list and type(dim_out) is not tuple:
         dim_out = [dim_out]
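For orientation, a minimal sketch of how the restored helper is typically invoked through its public wrapper (see the functools.partial assignments further down in this diff). The blob names, dimensions, and the scope argument below are illustrative assumptions, not values taken from this commit:

    from caffe2.python import model_helper, rnn_cell

    model = model_helper.ModelHelper(name="lstm_example")

    # input_blob is expected as T x N x D, e.g. T=20 timesteps, N=8 sequences, D=64 features.
    hidden_all, hidden_last, cell_all, cell_last = rnn_cell.LSTM(
        model=model,
        input_blob="input_sequence",   # hypothetical blob name
        seq_lengths="seq_lengths",     # per-sequence lengths, forwarded to LSTMUnit
        initial_states=None,           # let the helper create trainable initial states
        dim_in=64,
        dim_out=128,                   # single layer; pass a list of ints for stacked layers
        scope="example_lstm",          # assumed: scope argument from the full upstream signature
    )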
@@ -1014,24 +1007,19 @@ def _RNN(
                 shape=[dim_out[i]],
                 initializer=Initializer('ConstantFill', value=0.0),
             )
-            initial_states.append(initial_hidden)
-            if not no_cell_state:
-                initial_cell = model.create_param(
-                    'initial_cell_state' + suffix,
-                    shape=[dim_out[i]],
-                    initializer=Initializer('ConstantFill', value=0.0),
-                )
-                initial_states.append(initial_cell)
+            initial_cell = model.create_param(
+                'initial_cell_state' + suffix,
+                shape=[dim_out[i]],
+                initializer=Initializer('ConstantFill', value=0.0),
+            )
+            initial_states.extend([initial_hidden, initial_cell])
 
-    num_states = 1 if no_cell_state else 2
-    assert len(initial_states) == num_states * num_layers, \
-        "Incorrect initial_states," \
-        + " was expecting {} elements".format(num_states * num_layers) \
-        + " but got {} elements".format(len(initial_states))
+    assert len(initial_states) == 2 * num_layers, \
+        "Incorrect initial_states, was expecting 2 * num_layers elements" \
+        + " but had only {}".format(len(initial_states))
 
     # outputs_with_grads argument indexes into final layer
-    outputs_with_grads = [
-        2 * num_states * (num_layers - 1) + i for i in outputs_with_grads]
+    outputs_with_grads = [4 * (num_layers - 1) + i for i in outputs_with_grads]
     _, result = cell.apply_over_sequence(
         model=model,
         inputs=input_blob,
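The restored assert expects exactly two state blobs (hidden and cell) per layer when initial_states is supplied explicitly. For example, a two-layer LSTM would pass a flat list like the following (blob names are hypothetical):

    num_layers = 2
    initial_states = [
        "layer_0/initial_hidden_state", "layer_0/initial_cell_state",
        "layer_1/initial_hidden_state", "layer_1/initial_cell_state",
    ]
    assert len(initial_states) == 2 * num_layers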
@@ -1041,7 +1029,7 @@ def _RNN(
     )
 
     if return_last_layer_only:
-        result = result[2 * num_states * (num_layers - 1):]
+        result = result[4 * (num_layers - 1):]
     if return_params:
         result = list(result) + [{
             'input': cell.get_input_params(),
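The indexing restored above relies on each LSTM layer contributing four outputs (h_all, h_last, c_all, c_last), so the last layer's block starts at offset 4 * (num_layers - 1). A small worked check with illustrative values:

    num_layers = 3
    outputs_with_grads = [0]                  # gradients through h_all of the last layer

    # Outputs are laid out per layer: layer 0 -> indices 0..3, layer 1 -> 4..7, layer 2 -> 8..11.
    last_layer_offset = 4 * (num_layers - 1)  # == 8
    remapped = [last_layer_offset + i for i in outputs_with_grads]
    assert remapped == [8]

    # With return_last_layer_only=True the result is sliced the same way, result[8:],
    # keeping only the top layer's (h_all, h_last, c_all, c_last).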
@@ -1050,8 +1038,8 @@
     return tuple(result)
 
 
-LSTM = functools.partial(_RNN, LSTMCell)
-MILSTM = functools.partial(_RNN, MILSTMCell)
+LSTM = functools.partial(_LSTM, LSTMCell)
+MILSTM = functools.partial(_LSTM, MILSTMCell)
 
 
 class UnrolledCell(RNNCell):
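After the revert the public helpers again forward to _LSTM with the cell class pre-bound as the first positional argument, so rnn_cell.LSTM(...) behaves like _LSTM(LSTMCell, ...). A minimal, self-contained illustration of the same partial-application pattern (the helper function here is made up for the example):

    import functools

    def _rnn_helper(cell_class, model_name):
        # Stand-in for _LSTM: the first argument selects the cell implementation.
        return "{} applied to {}".format(cell_class, model_name)

    LSTM_like = functools.partial(_rnn_helper, "LSTMCell")
    print(LSTM_like("my_model"))  # -> LSTMCell applied to my_model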
@@ -1084,7 +1072,7 @@ class UnrolledCell(RNNCell):
             scope_name = "timestep_{}".format(t)
             # Parameters of all timesteps are shared
             with ParameterSharing({scope_name: ''}),\
-                    scope.NameScope(scope_name):
+                    scope.NameScope(scope_name):
                 timestep = model.param_init_net.ConstantFill(
                     [], "timestep", value=t, shape=[1],
                     dtype=core.DataType.INT32,
@@ -1222,10 +1210,10 @@ def cudnn_LSTM(model, input_blob, initial_states, dim_in, dim_out,
 
     # Multiply by 4 since we have 4 gates per LSTM unit
     first_layer_sz = input_weight_size + recurrent_weight_size + \
-        input_bias_size + recurrent_bias_size
+        input_bias_size + recurrent_bias_size
     upper_layer_sz = upper_layer_input_weight_size + \
-        recurrent_weight_size + input_bias_size + \
-        recurrent_bias_size
+        recurrent_weight_size + input_bias_size + \
+        recurrent_bias_size
     total_sz = 4 * (first_layer_sz + (num_layers - 1) * upper_layer_sz)
 
     weights = model.create_param(
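As a rough sanity check on the size computation above, a worked example with illustrative dimensions. The per-gate sizes assumed here (first-layer input weights of dim_in x hidden, hidden x hidden for upper-layer input weights and all recurrent weights, plus one input bias and one recurrent bias of length hidden per gate) follow the usual cuDNN LSTM layout and are assumptions about the surrounding variables, not values from this diff:

    dim_in, hidden, num_layers = 64, 128, 2

    # Assumed per-gate parameter counts (cuDNN-style layout).
    input_weight_size = dim_in * hidden               # first layer only
    upper_layer_input_weight_size = hidden * hidden   # upper layers read the layer below
    recurrent_weight_size = hidden * hidden
    input_bias_size = hidden
    recurrent_bias_size = hidden

    first_layer_sz = input_weight_size + recurrent_weight_size + \
        input_bias_size + recurrent_bias_size          # 24832
    upper_layer_sz = upper_layer_input_weight_size + \
        recurrent_weight_size + input_bias_size + \
        recurrent_bias_size                            # 33024

    # Multiply by 4 since there are 4 gates per LSTM unit.
    total_sz = 4 * (first_layer_sz + (num_layers - 1) * upper_layer_sz)
    print(total_sz)  # 231424 for these illustrative sizes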
@@ -1261,8 +1249,7 @@ def cudnn_LSTM(model, input_blob, initial_states, dim_in, dim_out,
     p = {}
     for pname in weight_params + bias_params:
         for j in range(0, num_layers):
-            values = p[pname] if pname in p else init(
-                j, pname, input_type)
+            values = p[pname] if pname in p else init(j, pname, input_type)
             model.param_init_net.RecurrentParamSet(
                 [input_blob, weights, values],
                 weights,