LSTM: support dropping hidden / cell states when a sequence ends

Summary:
This is useful when the data consists of standalone sequences that are
not connected to each other by any meaningful context.

Reviewed By: yqwangustc

Differential Revision: D4835164

fbshipit-source-id: f95626acc26acc3eba3bca7efb08ed1dbdb36c83
Author: Alexander Sidorov
Date: 2017-04-27 11:34:45 -07:00
Committed by: Facebook Github Bot
Parent: c4ce118393
Commit: ad6204eb0b
7 changed files with 105 additions and 34 deletions
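As a rough usage sketch (not part of this change): with the updated LSTM signature shown in the diff below, a caller would opt in via the new drop_states keyword. The blob names, dimensions, and model setup here are hypothetical placeholders, not code from this commit.

    from caffe2.python import model_helper, rnn_cell

    # Hypothetical model and blob names, for illustration only.
    model = model_helper.ModelHelper(name="lstm_example")

    result = rnn_cell.LSTM(
        model,
        input_blob="input",            # (T, N, dim_in) padded batch of sequences
        seq_lengths="seq_lengths",     # (N,) length of each standalone sequence
        initial_states=("hidden_init", "cell_init"),
        dim_in=64,
        dim_out=128,
        scope="lstm",
        drop_states=True,              # new flag: reset states once a sequence ends
    )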


@@ -129,12 +129,14 @@ class LSTMCell(RNNCell):
memory_optimization,
name,
forward_only=False,
drop_states=False,
):
super(LSTMCell, self).__init__(name, forward_only)
self.input_size = input_size
self.hidden_size = hidden_size
self.forget_bias = float(forget_bias)
self.memory_optimization = memory_optimization
self.drop_states = drop_states
def _apply(
self,
@@ -163,6 +165,7 @@ class LSTMCell(RNNCell):
],
list(self.get_state_names()),
forget_bias=self.forget_bias,
drop_states=self.drop_states,
)
model.net.AddExternalOutputs(hidden_t, cell_t)
if self.memory_optimization:
@@ -202,7 +205,8 @@ class LSTMCell(RNNCell):
def LSTM(model, input_blob, seq_lengths, initial_states, dim_in, dim_out,
scope, outputs_with_grads=(0,), return_params=False,
memory_optimization=False, forget_bias=0.0, forward_only=False):
memory_optimization=False, forget_bias=0.0, forward_only=False,
drop_states=False):
'''
Adds a standard LSTM recurrent network operator to a model.
@@ -241,6 +245,7 @@ def LSTM(model, input_blob, seq_lengths, initial_states, dim_in, dim_out,
memory_optimization=memory_optimization,
name=scope,
forward_only=forward_only,
drop_states=drop_states,
)
result = cell.apply_over_sequence(
model=model,
@@ -795,6 +800,7 @@ class MILSTMCell(LSTMCell):
[hidden_t_prev, cell_t_prev, gates_t_rs, seq_lengths, timestep],
[self.scope('hidden_t_intermediate'), self.scope('cell_t')],
forget_bias=self.forget_bias,
drop_states=self.drop_states,
)
hidden_t = model.Copy(hidden_t_intermediate, self.scope('hidden_t'))
model.net.AddExternalOutputs(
@@ -808,7 +814,7 @@ class MILSTMCell(LSTMCell):
def MILSTM(model, input_blob, seq_lengths, initial_states, dim_in, dim_out,
scope, outputs_with_grads=(0,), memory_optimization=False,
forget_bias=0.0, forward_only=False):
forget_bias=0.0, forward_only=False, drop_states=False):
'''
Adds MI flavor of standard LSTM recurrent network operator to a model.
See https://arxiv.org/pdf/1606.06630.pdf
@@ -831,9 +837,9 @@ def MILSTM(model, input_blob, seq_lengths, initial_states, dim_in, dim_out,
outputs_with_grads : position indices of output blobs which will receive
external error gradient during backpropagation
memory_optimization: if enabled, the LSTM step is recomputed on backward step
so that we don't need to store forward activations for each
timestep. Saves memory with cost of computation.
memory_optimization: if enabled, the LSTM step is recomputed on backward
step. So that we don't need to store forward activations for each timestep.
Saves memory with cost of computation.
forward_only run only forward pass
'''
@@ -844,6 +850,7 @@ def MILSTM(model, input_blob, seq_lengths, initial_states, dim_in, dim_out,
memory_optimization=memory_optimization,
name=scope,
forward_only=forward_only,
drop_states=drop_states,
)
result = cell.apply_over_sequence(
model=model,
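For context, the intent of the flag (per the summary above) is that once a batch element runs past its sequence length, its hidden and cell states are zeroed instead of being carried forward, so unrelated standalone sequences do not leak context into each other. Below is a plain-NumPy sketch of that per-timestep state selection under that assumption; it is an illustration, not the actual LSTMUnit operator code.

    import numpy as np

    def select_states(hidden_prev, cell_prev, hidden_new, cell_new,
                      t, seq_lengths, drop_states=False):
        # hidden_* / cell_* are (N, D) arrays for one timestep of a padded batch,
        # seq_lengths is an (N,) int array, t is the current timestep index.
        valid = (t < seq_lengths)[:, None]      # (N, 1) mask: sequence still running
        if drop_states:
            # Standalone sequences: zero the state once the sequence has ended.
            keep_h = np.zeros_like(hidden_prev)
            keep_c = np.zeros_like(cell_prev)
        else:
            # Default behaviour: carry the last valid state forward unchanged.
            keep_h = hidden_prev
            keep_c = cell_prev
        hidden_t = np.where(valid, hidden_new, keep_h)
        cell_t = np.where(valid, cell_new, keep_c)
        return hidden_t, cell_t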