DropoutCell as wrapper for another RNNCell

Summary: Added a new RNNCell, DropoutCell, which wraps an existing RNNCell and applies dropout to its primary output (as defined by get_output_state_index()).
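
A minimal usage sketch (editorial illustration, not part of this commit; the LSTMCell constructor arguments are assumed rather than taken from the diff):

    from caffe2.python import rnn_cell

    # Inner cell built as before; argument values here are illustrative.
    lstm = rnn_cell.LSTMCell(
        input_size=40,
        hidden_size=128,
        forget_bias=0.0,
        memory_optimization=False,
        name='lstm',
        forward_only=False,
    )

    # DropoutCell post-processes only the primary output; the recurrent
    # connections of the wrapped cell are left untouched.
    cell = rnn_cell.DropoutCell(
        internal_cell=lstm,
        dropout_ratio=0.5,
        name='lstm_dropout',
        forward_only=False,
    )
    # `cell` can now be used wherever an RNNCell is expected, e.g. via
    # cell.apply_over_sequence(...).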

Reviewed By: salexspb

Differential Revision: D5084871

fbshipit-source-id: 60474af84e5757a12e7fdc3814840dc9ba8e32a1
commit 03503140fd (parent c55be38e63)
Author: James Cross
Date: 2017-05-24 11:22:35 -07:00
Committed by: Facebook Github Bot


@@ -26,11 +26,10 @@ class RNNCell(object):
     As a result base class will provide apply_over_sequence method, which
     allows you to apply recurrent operations over a sequence of any length.
     '''
-    def __init__(self, name, dropout_ratio=None, forward_only=False):
+    def __init__(self, name, forward_only=False):
         self.name = name
         self.recompute_blobs = []
         self.forward_only = forward_only
-        self.dropout_ratio = dropout_ratio

     def scope(self, name):
         return self.name + '/' + name if self.name is not None else name
@@ -149,31 +148,27 @@ class RNNCell(object):
         '''
         raise NotImplementedError('Abstract method')

+    def get_output_dim(self):
+        '''
+        Specifies the dimension (number of units) of stepwise output.
+        '''
+        raise NotImplementedError('Abstract method')
+
     def _prepare_output(self, model, states):
-        output = states[self.get_output_state_index()]
-        if self.dropout_ratio is not None:
-            output = self._apply_dropout(model, output)
-        return output
+        '''
+        Allows arbitrary post-processing of primary output.
+        '''
+        return states[self.get_output_state_index()]

-    def _prepare_output_sequence(self, model, states):
-        output_state_index = 2 * self.get_output_state_index()
-        output = states[output_state_index]
-        if self.dropout_ratio is not None:
-            output = self._apply_dropout(model, output)
-        return output
+    def _prepare_output_sequence(self, model, state_outputs):
+        '''
+        Allows arbitrary post-processing of primary sequence output.

-    def _apply_dropout(self, model, output):
-        with core.NameScope(self.name or ''):
-            output, _ = model.net.Dropout(
-                output,
-                [
-                    str(output) + '_with_dropout',
-                    str(output) + '_dropout_mask',
-                ],
-                ratio=float(self.dropout_ratio),
-                is_test=int(self.forward_only),
-            )
-        return output
+        (Note that state_outputs alternates between full-sequence and final
+        output for each state, thus the index multiplier 2.)
+        '''
+        output_sequence_index = 2 * self.get_output_state_index()
+        return state_outputs[output_sequence_index]


 class LSTMCell(RNNCell):
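
To illustrate the indexing convention described in the new docstring (an editorial sketch with made-up blob names, not code from this diff): the per-state outputs alternate between the full sequence and the final step, so the full sequence for the output state lives at twice its state index.

    # Hypothetical layout for a two-state cell (e.g. hidden_t, cell_t):
    state_outputs = [
        'hidden_t_all',   # index 0: full sequence of state 0
        'hidden_t_last',  # index 1: final step of state 0
        'cell_t_all',     # index 2: full sequence of state 1
        'cell_t_last',    # index 3: final step of state 1
    ]
    output_state_index = 0  # e.g. an LSTM exposes hidden_t as its output
    assert state_outputs[2 * output_state_index] == 'hidden_t_all'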
@@ -272,7 +267,7 @@ class LSTMCell(RNNCell):
     def get_state_names(self):
         return (self.scope('hidden_t'), self.scope('cell_t'))

-    def get_output_size(self):
+    def get_output_dim(self):
         return self.hidden_size

@@ -382,6 +377,67 @@ class MILSTMCell(LSTMCell):
         return hidden_t, cell_t


+class DropoutCell(RNNCell):
+    '''
+    Wraps arbitrary RNNCell, applying dropout to its output (but not to the
+    recurrent connection for the corresponding state).
+    '''
+
+    def __init__(self, internal_cell, dropout_ratio=None, **kwargs):
+        self.internal_cell = internal_cell
+        self.dropout_ratio = dropout_ratio
+        super(DropoutCell, self).__init__(**kwargs)
+
+        self.prepare_input = internal_cell.prepare_input
+        self.get_output_state_index = internal_cell.get_output_state_index
+        self.get_state_names = internal_cell.get_state_names
+        self.get_output_dim = internal_cell.get_output_dim
+
+    def _apply(
+        self,
+        model,
+        input_t,
+        seq_lengths,
+        states,
+        timestep,
+        extra_inputs=None,
+    ):
+        return self.internal_cell._apply(
+            model,
+            input_t,
+            seq_lengths,
+            states,
+            timestep,
+            extra_inputs,
+        )
+
+    def _prepare_output(self, model, states):
+        output = states[self.get_output_state_index()]
+        if self.dropout_ratio is not None:
+            output = self._apply_dropout(model, output)
+        return output
+
+    def _prepare_output_sequence(self, model, state_outputs):
+        output_sequence_index = 2 * self.get_output_state_index()
+        output = state_outputs[output_sequence_index]
+        if self.dropout_ratio is not None:
+            output = self._apply_dropout(model, output)
+        return output
+
+    def _apply_dropout(self, model, output):
+        if self.dropout_ratio and not self.forward_only:
+            with core.NameScope(self.name or ''):
+                output, _ = model.net.Dropout(
+                    output,
+                    [
+                        str(output) + '_with_dropout',
+                        str(output) + '_dropout_mask',
+                    ],
+                    ratio=float(self.dropout_ratio),
+                )
+        return output
+
+
 class MultiRNNCell(RNNCell):
     '''
     Multilayer RNN via the composition of RNNCell instances.
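
The wrapper delegates most queries by assigning bound methods of the wrapped cell onto itself rather than overriding them; a small self-contained sketch of that pattern (plain Python, independent of Caffe2):

    class Inner(object):
        def get_state_names(self):
            return ('hidden_t', 'cell_t')

    class Wrapper(object):
        def __init__(self, inner):
            self.inner = inner
            # Assignment captures the *bound* method of `inner`, so calls on
            # the wrapper fall through without explicit forwarding code.
            self.get_state_names = inner.get_state_names

    w = Wrapper(Inner())
    assert w.get_state_names() == ('hidden_t', 'cell_t')

Note also that the relocated _apply_dropout guards on self.dropout_ratio and not self.forward_only, so forward-only (inference) nets omit the Dropout operator entirely rather than adding it with is_test set.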
@@ -505,8 +561,6 @@ class MultiRNNCell(RNNCell):
             states[-len(self.cells[-1].get_state_names()):],
         )
-        if self.dropout_ratio is not None:
-            output = self._apply_dropout(model, output)

         if (len(self.cells) - 1) in self.residual_output_layers:
             last_layer_input_index = 0
             for cell in self.cells[:-2]:
@@ -526,9 +580,6 @@ class MultiRNNCell(RNNCell):
             states[-(2 * len(self.cells[-1].get_state_names())):],
         )

-        if self.dropout_ratio is not None:
-            output = self._apply_dropout(model, output)
-
         if (len(self.cells) - 1) in self.residual_output_layers:
             last_layer_input_index = 0
             for cell in self.cells[:-2]:
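
With the dropout hooks removed from MultiRNNCell, per-layer output dropout would presumably now be configured by wrapping the individual layer cells before composing them. A sketch under that assumption (constructor arguments are illustrative, not taken from this diff):

    from caffe2.python import rnn_cell

    layers = []
    for i in range(3):
        lstm = rnn_cell.LSTMCell(
            input_size=128,
            hidden_size=128,
            forget_bias=0.0,
            memory_optimization=False,
            name='layer_{}'.format(i),
        )
        # Dropout on each layer's primary output; the recurrent state of the
        # wrapped layer is untouched.
        layers.append(rnn_cell.DropoutCell(
            internal_cell=lstm,
            dropout_ratio=0.3,
            name='layer_{}_dropout'.format(i),
        ))
    stack = rnn_cell.MultiRNNCell(layers, name='stack')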
@@ -578,10 +629,13 @@ class AttentionCell(RNNCell):
         seq_lengths,
         states,
         timestep,
+        extra_inputs=None,
     ):
         decoder_prev_states = states[:-1]
         attention_weighted_encoder_context_t_prev = states[-1]

+        assert extra_inputs is None
+
         decoder_states = self.decoder_cell._apply(
             model,
             input_t,
@@ -674,7 +728,7 @@ class AttentionCell(RNNCell):
         state_names.append(self.scope('attention_weighted_encoder_context_t'))
         return state_names

-    def get_output_size(self):
+    def get_output_dim(self):
         return self.decoder_state_dim + self.encoder_output_dim

     def get_output_state_index(self):
@@ -691,22 +745,20 @@ class AttentionCell(RNNCell):
                 ],
                 axis=2,
             )
-        if self.dropout_ratio is not None:
-            output = self._apply_dropout(model, output)
         return output

-    def _prepare_output_sequence(self, model, states):
+    def _prepare_output_sequence(self, model, state_outputs):
         decoder_output = self.decoder_cell._prepare_output_sequence(
             model,
-            states[:-2],
+            state_outputs[:-2],
         )
         attention_context_index = 2 * (len(self.get_state_names()) - 1)

         with core.NameScope(self.name or ''):
             output, _ = model.net.Concat(
                 [
                     decoder_output,
-                    states[attention_context_index],
+                    state_outputs[attention_context_index],
                 ],
                 [
                     'states_and_context_combination',
@@ -714,9 +766,6 @@ class AttentionCell(RNNCell):
                 ],
                 axis=2,
             )
-        if self.dropout_ratio is not None:
-            output = self._apply_dropout(model, output)
-
         return output