Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-21 13:44:15 +08:00
DropoutCell as wrapper for another RNNCell

Summary: Added a new RNNCell, DropoutCell, which wraps an existing RNNCell
and applies dropout to its primary output (as defined by
get_output_state_index()).

Reviewed By: salexspb

Differential Revision: D5084871

fbshipit-source-id: 60474af84e5757a12e7fdc3814840dc9ba8e32a1
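A minimal usage sketch of the new wrapper, not part of the commit: the
DropoutCell signature is taken from the diff below, while the LSTMCell
constructor arguments are an assumption about the contemporary
caffe2.python.rnn_cell API and should be checked against the file.

    from caffe2.python import rnn_cell

    # Assumed LSTMCell constructor; verify against caffe2/python/rnn_cell.py.
    lstm = rnn_cell.LSTMCell(
        input_size=64,
        hidden_size=128,
        forget_bias=0.0,
        memory_optimization=False,
        name='lstm',
    )

    # DropoutCell per this diff: the wrapped cell, a dropout ratio, and
    # **kwargs (name, forward_only) forwarded to RNNCell.__init__.
    cell = rnn_cell.DropoutCell(lstm, dropout_ratio=0.3, name='lstm_dropout')

Because DropoutCell delegates prepare_input, get_state_names, and the other
accessors to the wrapped cell, the wrapper is a drop-in replacement wherever
an RNNCell is expected, e.g. with the base class's apply_over_sequence.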
Commit 03503140fd (parent c55be38e63), committed by Facebook Github Bot.
@@ -26,11 +26,10 @@ class RNNCell(object):
     As a result base class will provice apply_over_sequence method, which
     allows you to apply recurrent operations over a sequence of any length.
     '''
-    def __init__(self, name, dropout_ratio=None, forward_only=False):
+    def __init__(self, name, forward_only=False):
         self.name = name
         self.recompute_blobs = []
         self.forward_only = forward_only
-        self.dropout_ratio = dropout_ratio

     def scope(self, name):
         return self.name + '/' + name if self.name is not None else name
@@ -149,31 +148,27 @@ class RNNCell(object):
         '''
         raise NotImplementedError('Abstract method')

     def get_output_dim(self):
         '''
         Specifies the dimension (number of units) of stepwise output.
         '''
         raise NotImplementedError('Abstract method')

     def _prepare_output(self, model, states):
-        output = states[self.get_output_state_index()]
-        if self.dropout_ratio is not None:
-            output = self._apply_dropout(model, output)
-        return output
+        '''
+        Allows arbitrary post-processing of primary output.
+        '''
+        return states[self.get_output_state_index()]

-    def _prepare_output_sequence(self, model, states):
-        output_state_index = 2 * self.get_output_state_index()
-        output = states[output_state_index]
-        if self.dropout_ratio is not None:
-            output = self._apply_dropout(model, output)
-        return output
-
-    def _apply_dropout(self, model, output):
-        with core.NameScope(self.name or ''):
-            output, _ = model.net.Dropout(
-                output,
-                [
-                    str(output) + '_with_dropout',
-                    str(output) + '_dropout_mask',
-                ],
-                ratio=float(self.dropout_ratio),
-                is_test=int(self.forward_only),
-            )
-        return output
+    def _prepare_output_sequence(self, model, state_outputs):
+        '''
+        Allows arbitrary post-processing of primary sequence output.
+
+        (Note that state_outputs alternates between full-sequence and final
+        output for each state, thus the index multiplier 2.)
+        '''
+        output_sequence_index = 2 * self.get_output_state_index()
+        return state_outputs[output_sequence_index]


 class LSTMCell(RNNCell):
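The index multiplier 2 in the new _prepare_output_sequence follows from how
the recurrent op lays out state_outputs; a worked illustration for an LSTM,
using the state names defined further down in this file:

    # LSTMCell.get_state_names() == ('hidden_t', 'cell_t'), and each state
    # contributes two blobs, full sequence first, final step second:
    #
    #   state_outputs = [
    #       hidden_all,   # 0: hidden state over the whole sequence
    #       hidden_last,  # 1: hidden state at the last step
    #       cell_all,     # 2: cell state over the whole sequence
    #       cell_last,    # 3: cell state at the last step
    #   ]
    #
    # get_output_state_index() defaults to 0 (the hidden state), so:
    output_sequence_index = 2 * 0  # == 0, i.e. hidden_all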
@@ -272,7 +267,7 @@ class LSTMCell(RNNCell):
     def get_state_names(self):
         return (self.scope('hidden_t'), self.scope('cell_t'))

-    def get_output_size(self):
+    def get_output_dim(self):
         return self.hidden_size

@@ -382,6 +377,67 @@ class MILSTMCell(LSTMCell):
         return hidden_t, cell_t


+class DropoutCell(RNNCell):
+    '''
+    Wraps arbitrary RNNCell, applying dropout to its output (but not to the
+    recurrent connection for the corresponding state).
+    '''
+
+    def __init__(self, internal_cell, dropout_ratio=None, **kwargs):
+        self.internal_cell = internal_cell
+        self.dropout_ratio = dropout_ratio
+        super(DropoutCell, self).__init__(**kwargs)
+
+        self.prepare_input = internal_cell.prepare_input
+        self.get_output_state_index = internal_cell.get_output_state_index
+        self.get_state_names = internal_cell.get_state_names
+        self.get_output_dim = internal_cell.get_output_dim
+
+    def _apply(
+        self,
+        model,
+        input_t,
+        seq_lengths,
+        states,
+        timestep,
+        extra_inputs=None,
+    ):
+        return self.internal_cell._apply(
+            model,
+            input_t,
+            seq_lengths,
+            states,
+            timestep,
+            extra_inputs,
+        )
+
+    def _prepare_output(self, model, states):
+        output = states[self.get_output_state_index()]
+        if self.dropout_ratio is not None:
+            output = self._apply_dropout(model, output)
+        return output
+
+    def _prepare_output_sequence(self, model, state_outputs):
+        output_sequence_index = 2 * self.get_output_state_index()
+        output = state_outputs[output_sequence_index]
+        if self.dropout_ratio is not None:
+            output = self._apply_dropout(model, output)
+        return output
+
+    def _apply_dropout(self, model, output):
+        if self.dropout_ratio and not self.forward_only:
+            with core.NameScope(self.name or ''):
+                output, _ = model.net.Dropout(
+                    output,
+                    [
+                        str(output) + '_with_dropout',
+                        str(output) + '_dropout_mask',
+                    ],
+                    ratio=float(self.dropout_ratio),
+                )
+        return output
+
+
 class MultiRNNCell(RNNCell):
     '''
     Multilayer RNN via the composition of RNNCell instance.
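Note the bound-method forwarding in DropoutCell.__init__ above: instead of
overriding prepare_input, get_output_state_index, get_state_names, and
get_output_dim one by one, the constructor rebinds the internal cell's bound
methods onto the wrapper. A minimal standalone sketch of the pattern, not
taken from the commit:

    class Wrapper(object):
        def __init__(self, inner):
            # Each attribute now holds inner's bound method, so calls on
            # the wrapper are answered exactly as the wrapped object would
            # answer them, with no per-method forwarding boilerplate.
            self.get_state_names = inner.get_state_names
            self.get_output_dim = inner.get_output_dim

One consequence is that self.get_output_state_index() inside
_prepare_output dispatches to the internal cell, so dropout lands on
whichever state the wrapped cell designates as its primary output.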
@@ -505,8 +561,6 @@ class MultiRNNCell(RNNCell):
             states[-len(self.cells[-1].get_state_names()):],
         )

-        if self.dropout_ratio is not None:
-            output = self._apply_dropout(model, output)
         if (len(self.cells) - 1) in self.residual_output_layers:
             last_layer_input_index = 0
             for cell in self.cells[:-2]:
@@ -526,9 +580,6 @@ class MultiRNNCell(RNNCell):
             states[-(2 * len(self.cells[-1].get_state_names())):],
         )

-        if self.dropout_ratio is not None:
-            output = self._apply_dropout(model, output)
-
         if (len(self.cells) - 1) in self.residual_output_layers:
             last_layer_input_index = 0
             for cell in self.cells[:-2]:
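With dropout removed from MultiRNNCell itself (the two hunks above),
per-layer dropout is presumably obtained by wrapping individual layers
before composing the stack. A hedged sketch; the MultiRNNCell constructor
shape is inferred from self.cells and residual_output_layers in this diff,
and layer1/layer2 stand for any already-constructed RNNCells:

    stacked = rnn_cell.MultiRNNCell(
        [
            # Dropout applies only to the first layer's output here.
            rnn_cell.DropoutCell(layer1, dropout_ratio=0.3, name='drop1'),
            layer2,
        ],
        name='stack',
    )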
@@ -578,10 +629,13 @@ class AttentionCell(RNNCell):
         seq_lengths,
         states,
         timestep,
         extra_inputs=None,
     ):
+        decoder_prev_states = states[:-1]
+        attention_weighted_encoder_context_t_prev = states[-1]
+
         assert extra_inputs is None

         decoder_states = self.decoder_cell._apply(
             model,
             input_t,
@@ -674,7 +728,7 @@ class AttentionCell(RNNCell):
         state_names.append(self.scope('attention_weighted_encoder_context_t'))
         return state_names

-    def get_output_size(self):
+    def get_output_dim(self):
         return self.decoder_state_dim + self.encoder_output_dim

     def get_output_state_index(self):
@@ -691,22 +745,20 @@ class AttentionCell(RNNCell):
                 ],
                 axis=2,
             )
-        if self.dropout_ratio is not None:
-            output = self._apply_dropout(model, output)

         return output

-    def _prepare_output_sequence(self, model, states):
+    def _prepare_output_sequence(self, model, state_outputs):
         decoder_output = self.decoder_cell._prepare_output_sequence(
             model,
-            states[:-2],
+            state_outputs[:-2],
         )

         attention_context_index = 2 * (len(self.get_state_names()) - 1)
         with core.NameScope(self.name or ''):
             output, _ = model.net.Concat(
                 [
                     decoder_output,
-                    states[attention_context_index],
+                    state_outputs[attention_context_index],
                 ],
                 [
                     'states_and_context_combination',
@@ -714,9 +766,6 @@ class AttentionCell(RNNCell):
                 ],
                 axis=2,
             )
-        if self.dropout_ratio is not None:
-            output = self._apply_dropout(model, output)
-
         return output

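The attention_context_index arithmetic in the hunks above, made concrete.
This is worked for an LSTM decoder; the count of three state names follows
from get_state_names, which appends the attention context to the decoder
cell's own states:

    # states: hidden_t, cell_t, attention_weighted_encoder_context_t
    num_states = 3
    # Full-sequence blob of the last state (the attention context):
    attention_context_index = 2 * (num_states - 1)  # == 4
    # state_outputs[:-2] strips the context's two blobs (full sequence and
    # final step), leaving exactly the decoder cell's own state_outputs.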