Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
Rewrite Python built-in class `super()` calls. Only non-semantic changes should be applied.

- #94587
- #94588
- #94592

Also, methods with only a `super()` call are removed:

```diff
 class MyModule(nn.Module):
-    def __init__(self):
-        super().__init__()
-
     def forward(self, ...):
         ...
```

Some cases that change the semantics should be kept unchanged. E.g.:

f152a79be9/caffe2/python/net_printer.py (L184-L190)
f152a79be9/test/test_jit_fuser_te.py (L2628-L2635)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94587
Approved by: https://github.com/ezyang
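For reference, a minimal sketch (hypothetical class name, not taken from the PR) of the kind of call-site rewrite these PRs apply:

```diff
 class MyCell(RNNCell):
     def __init__(self, **kwargs):
-        super(MyCell, self).__init__(**kwargs)
+        super().__init__(**kwargs)
```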
1979 lines
66 KiB
Python
## @package rnn_cell
|
|
# Module caffe2.python.rnn_cell
|
|
|
|
|
|
|
|
|
|
|
|
import functools
|
|
import inspect
|
|
import logging
|
|
import numpy as np
|
|
import random
|
|
|
|
from caffe2.proto import caffe2_pb2
|
|
from caffe2.python.attention import (
|
|
apply_dot_attention,
|
|
apply_recurrent_attention,
|
|
apply_regular_attention,
|
|
apply_soft_coverage_attention,
|
|
AttentionType,
|
|
)
|
|
from caffe2.python import core, recurrent, workspace, brew, scope, utils
|
|
from caffe2.python.modeling.parameter_sharing import ParameterSharing
|
|
from caffe2.python.modeling.parameter_info import ParameterTags
|
|
from caffe2.python.modeling.initializers import Initializer
|
|
from caffe2.python.model_helper import ModelHelper
|
|
|
|
|
|
def _RectifyName(blob_reference_or_name):
|
|
if blob_reference_or_name is None:
|
|
return None
|
|
if isinstance(blob_reference_or_name, str):
|
|
return core.ScopedBlobReference(blob_reference_or_name)
|
|
if not isinstance(blob_reference_or_name, core.BlobReference):
|
|
raise Exception("Unknown blob reference type")
|
|
return blob_reference_or_name
|
|
|
|
|
|
def _RectifyNames(blob_references_or_names):
|
|
if blob_references_or_names is None:
|
|
return None
|
|
return [_RectifyName(i) for i in blob_references_or_names]
|
|
|
|
|
|
class RNNCell:
|
|
'''
|
|
Base class for writing recurrent / stateful operations.
|
|
|
|
One needs to implement 2 methods: apply_override
|
|
and get_state_names_override.
|
|
|
|
As a result, the base class will provide the apply_over_sequence method, which
|
|
allows you to apply recurrent operations over a sequence of any length.
|
|
|
|
Optionally, you can add input and output preparation steps by overriding the
|
|
corresponding methods.
|
|
'''
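# A minimal sketch of a custom cell (hypothetical names, for illustration only):
#
#     class MyCell(RNNCell):
#         def apply_override(self, model, input_t, seq_lengths, states, timestep,
#                            extra_inputs=None):
#             # Add one step's worth of operators here and return the new states.
#             ...
#
#         def get_state_names_override(self):
#             return ['state_t']
#
# apply_over_sequence() can then unroll such a cell over a full input sequence.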
|
|
def __init__(self, name=None, forward_only=False, initializer=None):
|
|
self.name = name
|
|
self.recompute_blobs = []
|
|
self.forward_only = forward_only
|
|
self._initializer = initializer
|
|
|
|
@property
|
|
def initializer(self):
|
|
return self._initializer
|
|
|
|
@initializer.setter
|
|
def initializer(self, value):
|
|
self._initializer = value
|
|
|
|
def scope(self, name):
|
|
return self.name + '/' + name if self.name is not None else name
|
|
|
|
def apply_over_sequence(
|
|
self,
|
|
model,
|
|
inputs,
|
|
seq_lengths=None,
|
|
initial_states=None,
|
|
outputs_with_grads=None,
|
|
):
|
|
if initial_states is None:
|
|
with scope.NameScope(self.name):
|
|
if self.initializer is None:
|
|
raise Exception("Either initial states "
|
|
"or initializer have to be set")
|
|
initial_states = self.initializer.create_states(model)
|
|
|
|
preprocessed_inputs = self.prepare_input(model, inputs)
|
|
step_model = ModelHelper(name=self.name, param_model=model)
|
|
input_t, timestep = step_model.net.AddScopedExternalInputs(
|
|
'input_t',
|
|
'timestep',
|
|
)
|
|
utils.raiseIfNotEqual(
|
|
len(initial_states), len(self.get_state_names()),
|
|
"Number of initial state values provided doesn't match the number "
|
|
"of states"
|
|
)
|
|
states_prev = step_model.net.AddScopedExternalInputs(*[
|
|
s + '_prev' for s in self.get_state_names()
|
|
])
|
|
states = self._apply(
|
|
model=step_model,
|
|
input_t=input_t,
|
|
seq_lengths=seq_lengths,
|
|
states=states_prev,
|
|
timestep=timestep,
|
|
)
|
|
|
|
external_outputs = set(step_model.net.Proto().external_output)
|
|
for state in states:
|
|
if state not in external_outputs:
|
|
step_model.net.AddExternalOutput(state)
|
|
|
|
if outputs_with_grads is None:
|
|
outputs_with_grads = [self.get_output_state_index() * 2]
|
|
|
|
# states_for_all_steps consists of a combination of
|
|
# states gathered for all steps and final states. It looks like this:
|
|
# (state_1_all, state_1_final, state_2_all, state_2_final, ...)
|
|
states_for_all_steps = recurrent.recurrent_net(
|
|
net=model.net,
|
|
cell_net=step_model.net,
|
|
inputs=[(input_t, preprocessed_inputs)],
|
|
initial_cell_inputs=list(zip(states_prev, initial_states)),
|
|
links=dict(zip(states_prev, states)),
|
|
timestep=timestep,
|
|
scope=self.name,
|
|
forward_only=self.forward_only,
|
|
outputs_with_grads=outputs_with_grads,
|
|
recompute_blobs_on_backward=self.recompute_blobs,
|
|
)
|
|
|
|
output = self._prepare_output_sequence(
|
|
model,
|
|
states_for_all_steps,
|
|
)
|
|
return output, states_for_all_steps
|
|
|
|
def apply(self, model, input_t, seq_lengths, states, timestep):
|
|
input_t = self.prepare_input(model, input_t)
|
|
states = self._apply(
|
|
model, input_t, seq_lengths, states, timestep)
|
|
output = self._prepare_output(model, states)
|
|
return output, states
|
|
|
|
def _apply(
|
|
self,
|
|
model, input_t, seq_lengths, states, timestep, extra_inputs=None
|
|
):
|
|
'''
|
|
This method uses apply_override provided by a custom cell.
|
|
On top of that, it takes care of applying self.scope() to all the outputs,
|
|
while all the inputs stay within the scope this function was called
|
|
from.
|
|
'''
|
|
args = self._rectify_apply_inputs(
|
|
input_t, seq_lengths, states, timestep, extra_inputs)
|
|
with core.NameScope(self.name):
|
|
return self.apply_override(model, *args)
|
|
|
|
def _rectify_apply_inputs(
|
|
self, input_t, seq_lengths, states, timestep, extra_inputs):
|
|
'''
|
|
Before applying a scope we make sure that all external blob names
|
|
are converted to blob references, so further scoping doesn't affect them.
|
|
'''
|
|
|
|
input_t, seq_lengths, timestep = _RectifyNames(
|
|
[input_t, seq_lengths, timestep])
|
|
states = _RectifyNames(states)
|
|
if extra_inputs:
|
|
extra_input_names, extra_input_sizes = zip(*extra_inputs)
|
|
extra_inputs = _RectifyNames(extra_input_names)
|
|
extra_inputs = zip(extra_inputs, extra_input_sizes)  # pair rectified refs with sizes
|
|
|
|
arg_names = inspect.getfullargspec(self.apply_override).args
|
|
rectified = [input_t, seq_lengths, states, timestep]
|
|
if 'extra_inputs' in arg_names:
|
|
rectified.append(extra_inputs)
|
|
return rectified
|
|
|
|
|
|
def apply_override(
|
|
self,
|
|
model, input_t, seq_lengths, timestep, extra_inputs=None,
|
|
):
|
|
'''
|
|
A single step of a recurrent network to be implemented by each custom
|
|
RNNCell.
|
|
|
|
model: ModelHelper object new operators would be added to
|
|
|
|
input_t: single input with shape (1, batch_size, input_dim)
|
|
|
|
seq_lengths: blob containing sequence lengths which would be passed to
|
|
LSTMUnit operator
|
|
|
|
states: previous recurrent states
|
|
|
|
timestep: current recurrent iteration. Could be used together with
|
|
seq_lengths in order to determine if some shorter sequences
|
|
in the batch have already ended.
|
|
|
|
extra_inputs: list of tuples (input, dim). specifies additional input
|
|
which is not subject to prepare_input(). (useful when a cell is a
|
|
component of a larger recurrent structure, e.g., attention)
|
|
'''
|
|
raise NotImplementedError('Abstract method')
|
|
|
|
def prepare_input(self, model, input_blob):
|
|
'''
|
|
If some operations in _apply method depend only on the input,
|
|
not on recurrent states, they could be computed in advance.
|
|
|
|
model: ModelHelper object new operators would be added to
|
|
|
|
input_blob: either the whole input sequence with shape
|
|
(sequence_length, batch_size, input_dim) or a single input with shape
|
|
(1, batch_size, input_dim).
|
|
'''
|
|
return input_blob
|
|
|
|
def get_output_state_index(self):
|
|
'''
|
|
Return index into state list of the "primary" step-wise output.
|
|
'''
|
|
return 0
|
|
|
|
def get_state_names(self):
|
|
'''
|
|
Returns recurrent state names with self.name scoping applied
|
|
'''
|
|
return [self.scope(name) for name in self.get_state_names_override()]
|
|
|
|
def get_state_names_override(self):
|
|
'''
|
|
Override this function in your custom cell.
|
|
It should return the names of the recurrent states.
|
|
|
|
It's required by apply_over_sequence method in order to allocate
|
|
recurrent states for all steps with meaningful names.
|
|
'''
|
|
raise NotImplementedError('Abstract method')
|
|
|
|
def get_output_dim(self):
|
|
'''
|
|
Specifies the dimension (number of units) of stepwise output.
|
|
'''
|
|
raise NotImplementedError('Abstract method')
|
|
|
|
def _prepare_output(self, model, states):
|
|
'''
|
|
Allows arbitrary post-processing of primary output.
|
|
'''
|
|
return states[self.get_output_state_index()]
|
|
|
|
def _prepare_output_sequence(self, model, state_outputs):
|
|
'''
|
|
Allows arbitrary post-processing of primary sequence output.
|
|
|
|
(Note that state_outputs alternates between full-sequence and final
|
|
output for each state, thus the index multiplier 2.)
|
|
'''
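# For example (illustrative): for an LSTMCell whose states are
# ['hidden_t', 'cell_t'], state_outputs is
# (hidden_all, hidden_final, cell_all, cell_final), so output state index 0
# selects element 0 (hidden_all) here, and index 1 would select element 2.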
|
|
output_sequence_index = 2 * self.get_output_state_index()
|
|
return state_outputs[output_sequence_index]
|
|
|
|
|
|
class LSTMInitializer:
|
|
def __init__(self, hidden_size):
|
|
self.hidden_size = hidden_size
|
|
|
|
def create_states(self, model):
|
|
return [
|
|
model.create_param(
|
|
param_name='initial_hidden_state',
|
|
initializer=Initializer(operator_name='ConstantFill',
|
|
value=0.0),
|
|
shape=[self.hidden_size],
|
|
),
|
|
model.create_param(
|
|
param_name='initial_cell_state',
|
|
initializer=Initializer(operator_name='ConstantFill',
|
|
value=0.0),
|
|
shape=[self.hidden_size],
|
|
)
|
|
]
|
|
|
|
|
|
# based on https://pytorch.org/docs/master/nn.html#torch.nn.RNNCell
|
|
class BasicRNNCell(RNNCell):
|
|
def __init__(
|
|
self,
|
|
input_size,
|
|
hidden_size,
|
|
forget_bias,
|
|
memory_optimization,
|
|
drop_states=False,
|
|
initializer=None,
|
|
activation=None,
|
|
**kwargs
|
|
):
|
|
super().__init__(**kwargs)
|
|
self.drop_states = drop_states
|
|
self.input_size = input_size
|
|
self.hidden_size = hidden_size
|
|
self.activation = activation
|
|
|
|
if self.activation not in ['relu', 'tanh']:
|
|
raise RuntimeError(
|
|
'BasicRNNCell with unknown activation function (%s)'
|
|
% self.activation)
|
|
|
|
def apply_override(
|
|
self,
|
|
model,
|
|
input_t,
|
|
seq_lengths,
|
|
states,
|
|
timestep,
|
|
extra_inputs=None,
|
|
):
|
|
hidden_t_prev = states[0]
|
|
|
|
gates_t = brew.fc(
|
|
model,
|
|
hidden_t_prev,
|
|
'gates_t',
|
|
dim_in=self.hidden_size,
|
|
dim_out=self.hidden_size,
|
|
axis=2,
|
|
)
|
|
|
|
brew.sum(model, [gates_t, input_t], gates_t)
|
|
if self.activation == 'tanh':
|
|
hidden_t = model.net.Tanh(gates_t, 'hidden_t')
|
|
elif self.activation == 'relu':
|
|
hidden_t = model.net.Relu(gates_t, 'hidden_t')
|
|
else:
|
|
raise RuntimeError(
|
|
'BasicRNNCell with unknown activation function (%s)'
|
|
% self.activation)
|
|
|
|
if seq_lengths is not None:
|
|
# TODO If this codepath becomes popular, it may be worth
|
|
# taking a look at optimizing it - for now a simple
|
|
# implementation is used to round out compatibility with
|
|
# ONNX.
|
|
timestep = model.net.CopyFromCPUInput(
|
|
timestep, 'timestep_gpu')
|
|
valid_b = model.net.GT(
|
|
[seq_lengths, timestep], 'valid_b', broadcast=1)
|
|
invalid_b = model.net.LE(
|
|
[seq_lengths, timestep], 'invalid_b', broadcast=1)
|
|
valid = model.net.Cast(valid_b, 'valid', to='float')
|
|
invalid = model.net.Cast(invalid_b, 'invalid', to='float')
|
|
|
|
hidden_valid = model.net.Mul(
|
|
[hidden_t, valid],
|
|
'hidden_valid',
|
|
broadcast=1,
|
|
axis=1,
|
|
)
|
|
if self.drop_states:
|
|
hidden_t = hidden_valid
|
|
else:
|
|
hidden_invalid = model.net.Mul(
|
|
[hidden_t_prev, invalid],
|
|
'hidden_invalid',
|
|
broadcast=1, axis=1)
|
|
hidden_t = model.net.Add(
|
|
[hidden_valid, hidden_invalid], hidden_t)
|
|
return (hidden_t,)
|
|
|
|
def prepare_input(self, model, input_blob):
|
|
return brew.fc(
|
|
model,
|
|
input_blob,
|
|
self.scope('i2h'),
|
|
dim_in=self.input_size,
|
|
dim_out=self.hidden_size,
|
|
axis=2,
|
|
)
|
|
|
|
def get_state_names(self):
|
|
return (self.scope('hidden_t'),)
|
|
|
|
def get_output_dim(self):
|
|
return self.hidden_size
|
|
|
|
|
|
class LSTMCell(RNNCell):
|
|
|
|
def __init__(
|
|
self,
|
|
input_size,
|
|
hidden_size,
|
|
forget_bias,
|
|
memory_optimization,
|
|
drop_states=False,
|
|
initializer=None,
|
|
**kwargs
|
|
):
|
|
super().__init__(initializer=initializer, **kwargs)
|
|
self.initializer = initializer or LSTMInitializer(
|
|
hidden_size=hidden_size)
|
|
|
|
self.input_size = input_size
|
|
self.hidden_size = hidden_size
|
|
self.forget_bias = float(forget_bias)
|
|
self.memory_optimization = memory_optimization
|
|
self.drop_states = drop_states
|
|
self.gates_size = 4 * self.hidden_size
|
|
|
|
def apply_override(
|
|
self,
|
|
model,
|
|
input_t,
|
|
seq_lengths,
|
|
states,
|
|
timestep,
|
|
extra_inputs=None,
|
|
):
|
|
hidden_t_prev, cell_t_prev = states
|
|
|
|
fc_input = hidden_t_prev
|
|
fc_input_dim = self.hidden_size
|
|
|
|
if extra_inputs is not None:
|
|
extra_input_blobs, extra_input_sizes = zip(*extra_inputs)
|
|
fc_input = brew.concat(
|
|
model,
|
|
[hidden_t_prev] + list(extra_input_blobs),
|
|
'gates_concatenated_input_t',
|
|
axis=2,
|
|
)
|
|
fc_input_dim += sum(extra_input_sizes)
|
|
|
|
gates_t = brew.fc(
|
|
model,
|
|
fc_input,
|
|
'gates_t',
|
|
dim_in=fc_input_dim,
|
|
dim_out=self.gates_size,
|
|
axis=2,
|
|
)
|
|
brew.sum(model, [gates_t, input_t], gates_t)
|
|
|
|
if seq_lengths is not None:
|
|
inputs = [hidden_t_prev, cell_t_prev, gates_t, seq_lengths, timestep]
|
|
else:
|
|
inputs = [hidden_t_prev, cell_t_prev, gates_t, timestep]
|
|
|
|
hidden_t, cell_t = model.net.LSTMUnit(
|
|
inputs,
|
|
['hidden_state', 'cell_state'],
|
|
forget_bias=self.forget_bias,
|
|
drop_states=self.drop_states,
|
|
sequence_lengths=(seq_lengths is not None),
|
|
)
|
|
model.net.AddExternalOutputs(hidden_t, cell_t)
|
|
if self.memory_optimization:
|
|
self.recompute_blobs = [gates_t]
|
|
|
|
return hidden_t, cell_t
|
|
|
|
def get_input_params(self):
|
|
return {
|
|
'weights': self.scope('i2h') + '_w',
|
|
'biases': self.scope('i2h') + '_b',
|
|
}
|
|
|
|
def get_recurrent_params(self):
|
|
return {
|
|
'weights': self.scope('gates_t') + '_w',
|
|
'biases': self.scope('gates_t') + '_b',
|
|
}
|
|
|
|
def prepare_input(self, model, input_blob):
|
|
return brew.fc(
|
|
model,
|
|
input_blob,
|
|
self.scope('i2h'),
|
|
dim_in=self.input_size,
|
|
dim_out=self.gates_size,
|
|
axis=2,
|
|
)
|
|
|
|
def get_state_names_override(self):
|
|
return ['hidden_t', 'cell_t']
|
|
|
|
def get_output_dim(self):
|
|
return self.hidden_size
|
|
|
|
|
|
class LayerNormLSTMCell(RNNCell):
|
|
|
|
def __init__(
|
|
self,
|
|
input_size,
|
|
hidden_size,
|
|
forget_bias,
|
|
memory_optimization,
|
|
drop_states=False,
|
|
initializer=None,
|
|
**kwargs
|
|
):
|
|
super().__init__(initializer=initializer, **kwargs)
|
|
self.initializer = initializer or LSTMInitializer(
|
|
hidden_size=hidden_size
|
|
)
|
|
|
|
self.input_size = input_size
|
|
self.hidden_size = hidden_size
|
|
self.forget_bias = float(forget_bias)
|
|
self.memory_optimization = memory_optimization
|
|
self.drop_states = drop_states
|
|
self.gates_size = 4 * self.hidden_size
|
|
|
|
def _apply(
|
|
self,
|
|
model,
|
|
input_t,
|
|
seq_lengths,
|
|
states,
|
|
timestep,
|
|
extra_inputs=None,
|
|
):
|
|
hidden_t_prev, cell_t_prev = states
|
|
|
|
fc_input = hidden_t_prev
|
|
fc_input_dim = self.hidden_size
|
|
|
|
if extra_inputs is not None:
|
|
extra_input_blobs, extra_input_sizes = zip(*extra_inputs)
|
|
fc_input = brew.concat(
|
|
model,
|
|
[hidden_t_prev] + list(extra_input_blobs),
|
|
self.scope('gates_concatenated_input_t'),
|
|
axis=2,
|
|
)
|
|
fc_input_dim += sum(extra_input_sizes)
|
|
|
|
gates_t = brew.fc(
|
|
model,
|
|
fc_input,
|
|
self.scope('gates_t'),
|
|
dim_in=fc_input_dim,
|
|
dim_out=self.gates_size,
|
|
axis=2,
|
|
)
|
|
brew.sum(model, [gates_t, input_t], gates_t)
|
|
|
|
# The brew.layer_norm call is the only difference from LSTMCell
|
|
gates_t, _, _ = brew.layer_norm(
|
|
model,
|
|
self.scope('gates_t'),
|
|
self.scope('gates_t_norm'),
|
|
dim_in=self.gates_size,
|
|
axis=-1,
|
|
)
|
|
|
|
hidden_t, cell_t = model.net.LSTMUnit(
|
|
[
|
|
hidden_t_prev,
|
|
cell_t_prev,
|
|
gates_t,
|
|
seq_lengths,
|
|
timestep,
|
|
],
|
|
self.get_state_names(),
|
|
forget_bias=self.forget_bias,
|
|
drop_states=self.drop_states,
|
|
)
|
|
model.net.AddExternalOutputs(hidden_t, cell_t)
|
|
if self.memory_optimization:
|
|
self.recompute_blobs = [gates_t]
|
|
|
|
return hidden_t, cell_t
|
|
|
|
def get_input_params(self):
|
|
return {
|
|
'weights': self.scope('i2h') + '_w',
|
|
'biases': self.scope('i2h') + '_b',
|
|
}
|
|
|
|
def prepare_input(self, model, input_blob):
|
|
return brew.fc(
|
|
model,
|
|
input_blob,
|
|
self.scope('i2h'),
|
|
dim_in=self.input_size,
|
|
dim_out=self.gates_size,
|
|
axis=2,
|
|
)
|
|
|
|
def get_state_names(self):
|
|
return (self.scope('hidden_t'), self.scope('cell_t'))
|
|
|
|
|
|
class MILSTMCell(LSTMCell):
|
|
|
|
def _apply(
|
|
self,
|
|
model,
|
|
input_t,
|
|
seq_lengths,
|
|
states,
|
|
timestep,
|
|
extra_inputs=None,
|
|
):
|
|
hidden_t_prev, cell_t_prev = states
|
|
|
|
fc_input = hidden_t_prev
|
|
fc_input_dim = self.hidden_size
|
|
|
|
if extra_inputs is not None:
|
|
extra_input_blobs, extra_input_sizes = zip(*extra_inputs)
|
|
fc_input = brew.concat(
|
|
model,
|
|
[hidden_t_prev] + list(extra_input_blobs),
|
|
self.scope('gates_concatenated_input_t'),
|
|
axis=2,
|
|
)
|
|
fc_input_dim += sum(extra_input_sizes)
|
|
|
|
prev_t = brew.fc(
|
|
model,
|
|
fc_input,
|
|
self.scope('prev_t'),
|
|
dim_in=fc_input_dim,
|
|
dim_out=self.gates_size,
|
|
axis=2,
|
|
)
|
|
|
|
# defining initializers for MI parameters
|
|
alpha = model.create_param(
|
|
self.scope('alpha'),
|
|
shape=[self.gates_size],
|
|
initializer=Initializer('ConstantFill', value=1.0),
|
|
)
|
|
beta_h = model.create_param(
|
|
self.scope('beta1'),
|
|
shape=[self.gates_size],
|
|
initializer=Initializer('ConstantFill', value=1.0),
|
|
)
|
|
beta_i = model.create_param(
|
|
self.scope('beta2'),
|
|
shape=[self.gates_size],
|
|
initializer=Initializer('ConstantFill', value=1.0),
|
|
)
|
|
b = model.create_param(
|
|
self.scope('b'),
|
|
shape=[self.gates_size],
|
|
initializer=Initializer('ConstantFill', value=0.0),
|
|
)
|
|
|
|
# alpha * input_t + beta_h
|
|
# Shape: [1, batch_size, 4 * hidden_size]
|
|
alpha_by_input_t_plus_beta_h = model.net.ElementwiseLinear(
|
|
[input_t, alpha, beta_h],
|
|
self.scope('alpha_by_input_t_plus_beta_h'),
|
|
axis=2,
|
|
)
|
|
# (alpha * input_t + beta_h) * prev_t =
|
|
# alpha * input_t * prev_t + beta_h * prev_t
|
|
# Shape: [1, batch_size, 4 * hidden_size]
|
|
alpha_by_input_t_plus_beta_h_by_prev_t = model.net.Mul(
|
|
[alpha_by_input_t_plus_beta_h, prev_t],
|
|
self.scope('alpha_by_input_t_plus_beta_h_by_prev_t')
|
|
)
|
|
# beta_i * input_t + b
|
|
# Shape: [1, batch_size, 4 * hidden_size]
|
|
beta_i_by_input_t_plus_b = model.net.ElementwiseLinear(
|
|
[input_t, beta_i, b],
|
|
self.scope('beta_i_by_input_t_plus_b'),
|
|
axis=2,
|
|
)
|
|
# alpha * input_t * prev_t + beta_h * prev_t + beta_i * input_t + b
|
|
# Shape: [1, batch_size, 4 * hidden_size]
|
|
gates_t = brew.sum(
|
|
model,
|
|
[alpha_by_input_t_plus_beta_h_by_prev_t, beta_i_by_input_t_plus_b],
|
|
self.scope('gates_t')
|
|
)
|
|
hidden_t, cell_t = model.net.LSTMUnit(
|
|
[hidden_t_prev, cell_t_prev, gates_t, seq_lengths, timestep],
|
|
[self.scope('hidden_t_intermediate'), self.scope('cell_t')],
|
|
forget_bias=self.forget_bias,
|
|
drop_states=self.drop_states,
|
|
)
|
|
model.net.AddExternalOutputs(
|
|
cell_t,
|
|
hidden_t,
|
|
)
|
|
if self.memory_optimization:
|
|
self.recompute_blobs = [gates_t]
|
|
return hidden_t, cell_t
|
|
|
|
|
|
class LayerNormMILSTMCell(LSTMCell):
|
|
|
|
def _apply(
|
|
self,
|
|
model,
|
|
input_t,
|
|
seq_lengths,
|
|
states,
|
|
timestep,
|
|
extra_inputs=None,
|
|
):
|
|
hidden_t_prev, cell_t_prev = states
|
|
|
|
fc_input = hidden_t_prev
|
|
fc_input_dim = self.hidden_size
|
|
|
|
if extra_inputs is not None:
|
|
extra_input_blobs, extra_input_sizes = zip(*extra_inputs)
|
|
fc_input = brew.concat(
|
|
model,
|
|
[hidden_t_prev] + list(extra_input_blobs),
|
|
self.scope('gates_concatenated_input_t'),
|
|
axis=2,
|
|
)
|
|
fc_input_dim += sum(extra_input_sizes)
|
|
|
|
prev_t = brew.fc(
|
|
model,
|
|
fc_input,
|
|
self.scope('prev_t'),
|
|
dim_in=fc_input_dim,
|
|
dim_out=self.gates_size,
|
|
axis=2,
|
|
)
|
|
|
|
# defining initializers for MI parameters
|
|
alpha = model.create_param(
|
|
self.scope('alpha'),
|
|
shape=[self.gates_size],
|
|
initializer=Initializer('ConstantFill', value=1.0),
|
|
)
|
|
beta_h = model.create_param(
|
|
self.scope('beta1'),
|
|
shape=[self.gates_size],
|
|
initializer=Initializer('ConstantFill', value=1.0),
|
|
)
|
|
beta_i = model.create_param(
|
|
self.scope('beta2'),
|
|
shape=[self.gates_size],
|
|
initializer=Initializer('ConstantFill', value=1.0),
|
|
)
|
|
b = model.create_param(
|
|
self.scope('b'),
|
|
shape=[self.gates_size],
|
|
initializer=Initializer('ConstantFill', value=0.0),
|
|
)
|
|
|
|
# alpha * input_t + beta_h
|
|
# Shape: [1, batch_size, 4 * hidden_size]
|
|
alpha_by_input_t_plus_beta_h = model.net.ElementwiseLinear(
|
|
[input_t, alpha, beta_h],
|
|
self.scope('alpha_by_input_t_plus_beta_h'),
|
|
axis=2,
|
|
)
|
|
# (alpha * input_t + beta_h) * prev_t =
|
|
# alpha * input_t * prev_t + beta_h * prev_t
|
|
# Shape: [1, batch_size, 4 * hidden_size]
|
|
alpha_by_input_t_plus_beta_h_by_prev_t = model.net.Mul(
|
|
[alpha_by_input_t_plus_beta_h, prev_t],
|
|
self.scope('alpha_by_input_t_plus_beta_h_by_prev_t')
|
|
)
|
|
# beta_i * input_t + b
|
|
# Shape: [1, batch_size, 4 * hidden_size]
|
|
beta_i_by_input_t_plus_b = model.net.ElementwiseLinear(
|
|
[input_t, beta_i, b],
|
|
self.scope('beta_i_by_input_t_plus_b'),
|
|
axis=2,
|
|
)
|
|
# alpha * input_t * prev_t + beta_h * prev_t + beta_i * input_t + b
|
|
# Shape: [1, batch_size, 4 * hidden_size]
|
|
gates_t = brew.sum(
|
|
model,
|
|
[alpha_by_input_t_plus_beta_h_by_prev_t, beta_i_by_input_t_plus_b],
|
|
self.scope('gates_t')
|
|
)
|
|
# The brew.layer_norm call is the only difference from MILSTMCell._apply
|
|
gates_t, _, _ = brew.layer_norm(
|
|
model,
|
|
self.scope('gates_t'),
|
|
self.scope('gates_t_norm'),
|
|
dim_in=self.gates_size,
|
|
axis=-1,
|
|
)
|
|
hidden_t, cell_t = model.net.LSTMUnit(
|
|
[hidden_t_prev, cell_t_prev, gates_t, seq_lengths, timestep],
|
|
[self.scope('hidden_t_intermediate'), self.scope('cell_t')],
|
|
forget_bias=self.forget_bias,
|
|
drop_states=self.drop_states,
|
|
)
|
|
model.net.AddExternalOutputs(
|
|
cell_t,
|
|
hidden_t,
|
|
)
|
|
if self.memory_optimization:
|
|
self.recompute_blobs = [gates_t]
|
|
return hidden_t, cell_t
|
|
|
|
|
|
class DropoutCell(RNNCell):
|
|
'''
|
|
Wraps arbitrary RNNCell, applying dropout to its output (but not to the
|
|
recurrent connection for the corresponding state).
|
|
'''
|
|
|
|
def __init__(
|
|
self,
|
|
internal_cell,
|
|
dropout_ratio=None,
|
|
use_cudnn=False,
|
|
**kwargs
|
|
):
|
|
self.internal_cell = internal_cell
|
|
self.dropout_ratio = dropout_ratio
|
|
assert 'is_test' in kwargs, "Argument 'is_test' is required"
|
|
self.is_test = kwargs.pop('is_test')
|
|
self.use_cudnn = use_cudnn
|
|
super().__init__(**kwargs)
|
|
|
|
self.prepare_input = internal_cell.prepare_input
|
|
self.get_output_state_index = internal_cell.get_output_state_index
|
|
self.get_state_names = internal_cell.get_state_names
|
|
self.get_output_dim = internal_cell.get_output_dim
|
|
|
|
self.mask = 0
|
|
|
|
def _apply(
|
|
self,
|
|
model,
|
|
input_t,
|
|
seq_lengths,
|
|
states,
|
|
timestep,
|
|
extra_inputs=None,
|
|
):
|
|
return self.internal_cell._apply(
|
|
model,
|
|
input_t,
|
|
seq_lengths,
|
|
states,
|
|
timestep,
|
|
extra_inputs,
|
|
)
|
|
|
|
def _prepare_output(self, model, states):
|
|
output = self.internal_cell._prepare_output(
|
|
model,
|
|
states,
|
|
)
|
|
if self.dropout_ratio is not None:
|
|
output = self._apply_dropout(model, output)
|
|
return output
|
|
|
|
def _prepare_output_sequence(self, model, state_outputs):
|
|
output = self.internal_cell._prepare_output_sequence(
|
|
model,
|
|
state_outputs,
|
|
)
|
|
if self.dropout_ratio is not None:
|
|
output = self._apply_dropout(model, output)
|
|
return output
|
|
|
|
def _apply_dropout(self, model, output):
|
|
if self.dropout_ratio and not self.forward_only:
|
|
with core.NameScope(self.name or ''):
|
|
output = brew.dropout(
|
|
model,
|
|
output,
|
|
str(output) + '_with_dropout_mask{}'.format(self.mask),
|
|
ratio=float(self.dropout_ratio),
|
|
is_test=self.is_test,
|
|
use_cudnn=self.use_cudnn,
|
|
)
|
|
self.mask += 1
|
|
return output
|
|
|
|
|
|
class MultiRNNCellInitializer:
|
|
def __init__(self, cells):
|
|
self.cells = cells
|
|
|
|
def create_states(self, model):
|
|
states = []
|
|
for i, cell in enumerate(self.cells):
|
|
if cell.initializer is None:
|
|
raise Exception("Either initial states "
|
|
"or initializer have to be set")
|
|
|
|
with core.NameScope("layer_{}".format(i)),\
|
|
core.NameScope(cell.name):
|
|
states.extend(cell.initializer.create_states(model))
|
|
return states
|
|
|
|
|
|
class MultiRNNCell(RNNCell):
|
|
'''
|
|
Multilayer RNN via the composition of RNNCell instances.
|
|
|
|
It is the responsibility of calling code to ensure the compatibility
|
|
of the successive layers in terms of input/output dimensionality, etc.,
|
|
and to ensure that their blobs do not have name conflicts, typically by
|
|
creating the cells with names that specify layer number.
|
|
|
|
Assumes first state (recurrent output) for each layer should be the input
|
|
to the next layer.
|
|
'''
|
|
|
|
def __init__(self, cells, residual_output_layers=None, **kwargs):
|
|
'''
|
|
cells: list of RNNCell instances, from input to output side.
|
|
|
|
name: string designating network component (for scoping)
|
|
|
|
residual_output_layers: list of indices of layers whose input will
|
|
be added elementwise to their output. (It is the
|
|
responsibility of the client code to ensure shape compatibility.)
|
|
Note that layer 0 (zero) cannot have residual output because of the
|
|
timing of prepare_input().
|
|
|
|
forward_only: used to construct inference-only network.
|
|
'''
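# A minimal construction sketch (hypothetical sizes and names, for
# illustration only):
#
#     cells = [
#         LSTMCell(input_size=64, hidden_size=128, forget_bias=0.0,
#                  memory_optimization=False, name='layer_0'),
#         LSTMCell(input_size=128, hidden_size=128, forget_bias=0.0,
#                  memory_optimization=False, name='layer_1'),
#     ]
#     stacked = MultiRNNCell(cells, name='stacked_lstm', forward_only=False)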
|
|
super().__init__(**kwargs)
|
|
self.cells = cells
|
|
|
|
if residual_output_layers is None:
|
|
self.residual_output_layers = []
|
|
else:
|
|
self.residual_output_layers = residual_output_layers
|
|
|
|
output_index_per_layer = []
|
|
base_index = 0
|
|
for cell in self.cells:
|
|
output_index_per_layer.append(
|
|
base_index + cell.get_output_state_index(),
|
|
)
|
|
base_index += len(cell.get_state_names())
|
|
|
|
self.output_connected_layers = []
|
|
self.output_indices = []
|
|
for i in range(len(self.cells) - 1):
|
|
if (i + 1) in self.residual_output_layers:
|
|
self.output_connected_layers.append(i)
|
|
self.output_indices.append(output_index_per_layer[i])
|
|
else:
|
|
self.output_connected_layers = []
|
|
self.output_indices = []
|
|
self.output_connected_layers.append(len(self.cells) - 1)
|
|
self.output_indices.append(output_index_per_layer[-1])
|
|
|
|
self.state_names = []
|
|
for i, cell in enumerate(self.cells):
|
|
self.state_names.extend(
|
|
map(self.layer_scoper(i), cell.get_state_names())
|
|
)
|
|
|
|
self.initializer = MultiRNNCellInitializer(cells)
|
|
|
|
def layer_scoper(self, layer_id):
|
|
def helper(name):
|
|
return "{}/layer_{}/{}".format(self.name, layer_id, name)
|
|
return helper
|
|
|
|
def prepare_input(self, model, input_blob):
|
|
input_blob = _RectifyName(input_blob)
|
|
with core.NameScope(self.name or ''):
|
|
return self.cells[0].prepare_input(model, input_blob)
|
|
|
|
def _apply(
|
|
self,
|
|
model,
|
|
input_t,
|
|
seq_lengths,
|
|
states,
|
|
timestep,
|
|
extra_inputs=None,
|
|
):
|
|
'''
|
|
Because below we will do scoping across layers, we need
|
|
to make sure that string blob names are converted to BlobReference
|
|
objects.
|
|
'''
|
|
|
|
input_t, seq_lengths, states, timestep, extra_inputs = \
|
|
self._rectify_apply_inputs(
|
|
input_t, seq_lengths, states, timestep, extra_inputs)
|
|
|
|
states_per_layer = [len(cell.get_state_names()) for cell in self.cells]
|
|
assert len(states) == sum(states_per_layer)
|
|
|
|
next_states = []
|
|
states_index = 0
|
|
|
|
layer_input = input_t
|
|
for i, layer_cell in enumerate(self.cells):
|
|
# If cells don't have different names, we still
|
|
# take care of scoping
|
|
with core.NameScope(self.name), core.NameScope("layer_{}".format(i)):
|
|
num_states = states_per_layer[i]
|
|
layer_states = states[states_index:(states_index + num_states)]
|
|
states_index += num_states
|
|
|
|
if i > 0:
|
|
prepared_input = layer_cell.prepare_input(
|
|
model, layer_input)
|
|
else:
|
|
prepared_input = layer_input
|
|
|
|
layer_next_states = layer_cell._apply(
|
|
model,
|
|
prepared_input,
|
|
seq_lengths,
|
|
layer_states,
|
|
timestep,
|
|
extra_inputs=(None if i > 0 else extra_inputs),
|
|
)
|
|
# Since we're using here non-public method _apply,
|
|
# instead of apply, we have to manually extract output
|
|
# from states
|
|
if i != len(self.cells) - 1:
|
|
layer_output = layer_cell._prepare_output(
|
|
model,
|
|
layer_next_states,
|
|
)
|
|
if i > 0 and i in self.residual_output_layers:
|
|
layer_input = brew.sum(
|
|
model,
|
|
[layer_output, layer_input],
|
|
self.scope('residual_output_{}'.format(i)),
|
|
)
|
|
else:
|
|
layer_input = layer_output
|
|
|
|
next_states.extend(layer_next_states)
|
|
return next_states
|
|
|
|
def get_state_names(self):
|
|
return self.state_names
|
|
|
|
def get_output_state_index(self):
|
|
index = 0
|
|
for cell in self.cells[:-1]:
|
|
index += len(cell.get_state_names())
|
|
index += self.cells[-1].get_output_state_index()
|
|
return index
|
|
|
|
def _prepare_output(self, model, states):
|
|
connected_outputs = []
|
|
state_index = 0
|
|
for i, cell in enumerate(self.cells):
|
|
num_states = len(cell.get_state_names())
|
|
if i in self.output_connected_layers:
|
|
layer_states = states[state_index:state_index + num_states]
|
|
layer_output = cell._prepare_output(
|
|
model,
|
|
layer_states
|
|
)
|
|
connected_outputs.append(layer_output)
|
|
state_index += num_states
|
|
if len(connected_outputs) > 1:
|
|
output = brew.sum(
|
|
model,
|
|
connected_outputs,
|
|
self.scope('residual_output'),
|
|
)
|
|
else:
|
|
output = connected_outputs[0]
|
|
return output
|
|
|
|
def _prepare_output_sequence(self, model, states):
|
|
connected_outputs = []
|
|
state_index = 0
|
|
for i, cell in enumerate(self.cells):
|
|
num_states = 2 * len(cell.get_state_names())
|
|
if i in self.output_connected_layers:
|
|
layer_states = states[state_index:state_index + num_states]
|
|
layer_output = cell._prepare_output_sequence(
|
|
model,
|
|
layer_states
|
|
)
|
|
connected_outputs.append(layer_output)
|
|
state_index += num_states
|
|
if len(connected_outputs) > 1:
|
|
output = brew.sum(
|
|
model,
|
|
connected_outputs,
|
|
self.scope('residual_output_sequence'),
|
|
)
|
|
else:
|
|
output = connected_outputs[0]
|
|
return output
|
|
|
|
|
|
class AttentionCell(RNNCell):
|
|
|
|
def __init__(
|
|
self,
|
|
encoder_output_dim,
|
|
encoder_outputs,
|
|
encoder_lengths,
|
|
decoder_cell,
|
|
decoder_state_dim,
|
|
attention_type,
|
|
weighted_encoder_outputs,
|
|
attention_memory_optimization,
|
|
**kwargs
|
|
):
|
|
super().__init__(**kwargs)
|
|
self.encoder_output_dim = encoder_output_dim
|
|
self.encoder_outputs = encoder_outputs
|
|
self.encoder_lengths = encoder_lengths
|
|
self.decoder_cell = decoder_cell
|
|
self.decoder_state_dim = decoder_state_dim
|
|
self.weighted_encoder_outputs = weighted_encoder_outputs
|
|
self.encoder_outputs_transposed = None
|
|
assert attention_type in [
|
|
AttentionType.Regular,
|
|
AttentionType.Recurrent,
|
|
AttentionType.Dot,
|
|
AttentionType.SoftCoverage,
|
|
]
|
|
self.attention_type = attention_type
|
|
self.attention_memory_optimization = attention_memory_optimization
|
|
|
|
def _apply(
|
|
self,
|
|
model,
|
|
input_t,
|
|
seq_lengths,
|
|
states,
|
|
timestep,
|
|
extra_inputs=None,
|
|
):
|
|
if self.attention_type == AttentionType.SoftCoverage:
|
|
decoder_prev_states = states[:-2]
|
|
attention_weighted_encoder_context_t_prev = states[-2]
|
|
coverage_t_prev = states[-1]
|
|
else:
|
|
decoder_prev_states = states[:-1]
|
|
attention_weighted_encoder_context_t_prev = states[-1]
|
|
|
|
assert extra_inputs is None
|
|
|
|
decoder_states = self.decoder_cell._apply(
|
|
model,
|
|
input_t,
|
|
seq_lengths,
|
|
decoder_prev_states,
|
|
timestep,
|
|
extra_inputs=[(
|
|
attention_weighted_encoder_context_t_prev,
|
|
self.encoder_output_dim,
|
|
)],
|
|
)
|
|
|
|
self.hidden_t_intermediate = self.decoder_cell._prepare_output(
|
|
model,
|
|
decoder_states,
|
|
)
|
|
|
|
if self.attention_type == AttentionType.Recurrent:
|
|
(
|
|
attention_weighted_encoder_context_t,
|
|
self.attention_weights_3d,
|
|
attention_blobs,
|
|
) = apply_recurrent_attention(
|
|
model=model,
|
|
encoder_output_dim=self.encoder_output_dim,
|
|
encoder_outputs_transposed=self.encoder_outputs_transposed,
|
|
weighted_encoder_outputs=self.weighted_encoder_outputs,
|
|
decoder_hidden_state_t=self.hidden_t_intermediate,
|
|
decoder_hidden_state_dim=self.decoder_state_dim,
|
|
scope=self.name,
|
|
attention_weighted_encoder_context_t_prev=(
|
|
attention_weighted_encoder_context_t_prev
|
|
),
|
|
encoder_lengths=self.encoder_lengths,
|
|
)
|
|
elif self.attention_type == AttentionType.Regular:
|
|
(
|
|
attention_weighted_encoder_context_t,
|
|
self.attention_weights_3d,
|
|
attention_blobs,
|
|
) = apply_regular_attention(
|
|
model=model,
|
|
encoder_output_dim=self.encoder_output_dim,
|
|
encoder_outputs_transposed=self.encoder_outputs_transposed,
|
|
weighted_encoder_outputs=self.weighted_encoder_outputs,
|
|
decoder_hidden_state_t=self.hidden_t_intermediate,
|
|
decoder_hidden_state_dim=self.decoder_state_dim,
|
|
scope=self.name,
|
|
encoder_lengths=self.encoder_lengths,
|
|
)
|
|
elif self.attention_type == AttentionType.Dot:
|
|
(
|
|
attention_weighted_encoder_context_t,
|
|
self.attention_weights_3d,
|
|
attention_blobs,
|
|
) = apply_dot_attention(
|
|
model=model,
|
|
encoder_output_dim=self.encoder_output_dim,
|
|
encoder_outputs_transposed=self.encoder_outputs_transposed,
|
|
decoder_hidden_state_t=self.hidden_t_intermediate,
|
|
decoder_hidden_state_dim=self.decoder_state_dim,
|
|
scope=self.name,
|
|
encoder_lengths=self.encoder_lengths,
|
|
)
|
|
elif self.attention_type == AttentionType.SoftCoverage:
|
|
(
|
|
attention_weighted_encoder_context_t,
|
|
self.attention_weights_3d,
|
|
attention_blobs,
|
|
coverage_t,
|
|
) = apply_soft_coverage_attention(
|
|
model=model,
|
|
encoder_output_dim=self.encoder_output_dim,
|
|
encoder_outputs_transposed=self.encoder_outputs_transposed,
|
|
weighted_encoder_outputs=self.weighted_encoder_outputs,
|
|
decoder_hidden_state_t=self.hidden_t_intermediate,
|
|
decoder_hidden_state_dim=self.decoder_state_dim,
|
|
scope=self.name,
|
|
encoder_lengths=self.encoder_lengths,
|
|
coverage_t_prev=coverage_t_prev,
|
|
coverage_weights=self.coverage_weights,
|
|
)
|
|
else:
|
|
raise Exception('Attention type {} not implemented'.format(
|
|
self.attention_type
|
|
))
|
|
|
|
if self.attention_memory_optimization:
|
|
self.recompute_blobs.extend(attention_blobs)
|
|
|
|
output = list(decoder_states) + [attention_weighted_encoder_context_t]
|
|
if self.attention_type == AttentionType.SoftCoverage:
|
|
output.append(coverage_t)
|
|
|
|
output[self.decoder_cell.get_output_state_index()] = model.Copy(
|
|
output[self.decoder_cell.get_output_state_index()],
|
|
self.scope('hidden_t_external'),
|
|
)
|
|
model.net.AddExternalOutputs(*output)
|
|
|
|
return output
|
|
|
|
def get_attention_weights(self):
|
|
# [batch_size, encoder_length, 1]
|
|
return self.attention_weights_3d
|
|
|
|
def prepare_input(self, model, input_blob):
|
|
if self.encoder_outputs_transposed is None:
|
|
self.encoder_outputs_transposed = brew.transpose(
|
|
model,
|
|
self.encoder_outputs,
|
|
self.scope('encoder_outputs_transposed'),
|
|
axes=[1, 2, 0],
|
|
)
|
|
if (
|
|
self.weighted_encoder_outputs is None and
|
|
self.attention_type != AttentionType.Dot
|
|
):
|
|
self.weighted_encoder_outputs = brew.fc(
|
|
model,
|
|
self.encoder_outputs,
|
|
self.scope('weighted_encoder_outputs'),
|
|
dim_in=self.encoder_output_dim,
|
|
dim_out=self.encoder_output_dim,
|
|
axis=2,
|
|
)
|
|
|
|
return self.decoder_cell.prepare_input(model, input_blob)
|
|
|
|
def build_initial_coverage(self, model):
|
|
"""
|
|
initial_coverage is always zeros of shape [encoder_length],
|
|
whose shape must be determined programmatically during network
|
|
computation.
|
|
|
|
This method also sets self.coverage_weights, a separate transform
|
|
of encoder_outputs which is used to determine coverage contribution
|
|
to attention.
|
|
"""
|
|
assert self.attention_type == AttentionType.SoftCoverage
|
|
|
|
# [encoder_length, batch_size, encoder_output_dim]
|
|
self.coverage_weights = brew.fc(
|
|
model,
|
|
self.encoder_outputs,
|
|
self.scope('coverage_weights'),
|
|
dim_in=self.encoder_output_dim,
|
|
dim_out=self.encoder_output_dim,
|
|
axis=2,
|
|
)
|
|
|
|
encoder_length = model.net.Slice(
|
|
model.net.Shape(self.encoder_outputs),
|
|
starts=[0],
|
|
ends=[1],
|
|
)
|
|
if (
|
|
scope.CurrentDeviceScope() is not None and
|
|
core.IsGPUDeviceType(scope.CurrentDeviceScope().device_type)
|
|
):
|
|
encoder_length = model.net.CopyGPUToCPU(
|
|
encoder_length,
|
|
'encoder_length_cpu',
|
|
)
|
|
# total attention weight applied across decoding steps_per_checkpoint
|
|
# shape: [encoder_length]
|
|
initial_coverage = model.net.ConstantFill(
|
|
encoder_length,
|
|
self.scope('initial_coverage'),
|
|
value=0.0,
|
|
input_as_shape=1,
|
|
)
|
|
return initial_coverage
|
|
|
|
def get_state_names(self):
|
|
state_names = list(self.decoder_cell.get_state_names())
|
|
state_names[self.get_output_state_index()] = self.scope(
|
|
'hidden_t_external',
|
|
)
|
|
state_names.append(self.scope('attention_weighted_encoder_context_t'))
|
|
if self.attention_type == AttentionType.SoftCoverage:
|
|
state_names.append(self.scope('coverage_t'))
|
|
return state_names
|
|
|
|
def get_output_dim(self):
|
|
return self.decoder_state_dim + self.encoder_output_dim
|
|
|
|
def get_output_state_index(self):
|
|
return self.decoder_cell.get_output_state_index()
|
|
|
|
def _prepare_output(self, model, states):
|
|
if self.attention_type == AttentionType.SoftCoverage:
|
|
attention_context = states[-2]
|
|
else:
|
|
attention_context = states[-1]
|
|
|
|
with core.NameScope(self.name or ''):
|
|
output = brew.concat(
|
|
model,
|
|
[self.hidden_t_intermediate, attention_context],
|
|
'states_and_context_combination',
|
|
axis=2,
|
|
)
|
|
|
|
return output
|
|
|
|
def _prepare_output_sequence(self, model, state_outputs):
|
|
if self.attention_type == AttentionType.SoftCoverage:
|
|
decoder_state_outputs = state_outputs[:-4]
|
|
else:
|
|
decoder_state_outputs = state_outputs[:-2]
|
|
|
|
decoder_output = self.decoder_cell._prepare_output_sequence(
|
|
model,
|
|
decoder_state_outputs,
|
|
)
|
|
|
|
if self.attention_type == AttentionType.SoftCoverage:
|
|
attention_context_index = 2 * (len(self.get_state_names()) - 2)
|
|
else:
|
|
attention_context_index = 2 * (len(self.get_state_names()) - 1)
|
|
|
|
with core.NameScope(self.name or ''):
|
|
output = brew.concat(
|
|
model,
|
|
[
|
|
decoder_output,
|
|
state_outputs[attention_context_index],
|
|
],
|
|
'states_and_context_combination',
|
|
axis=2,
|
|
)
|
|
return output
|
|
|
|
|
|
class LSTMWithAttentionCell(AttentionCell):
|
|
|
|
def __init__(
|
|
self,
|
|
encoder_output_dim,
|
|
encoder_outputs,
|
|
encoder_lengths,
|
|
decoder_input_dim,
|
|
decoder_state_dim,
|
|
name,
|
|
attention_type,
|
|
weighted_encoder_outputs,
|
|
forget_bias,
|
|
lstm_memory_optimization,
|
|
attention_memory_optimization,
|
|
forward_only=False,
|
|
):
|
|
decoder_cell = LSTMCell(
|
|
input_size=decoder_input_dim,
|
|
hidden_size=decoder_state_dim,
|
|
forget_bias=forget_bias,
|
|
memory_optimization=lstm_memory_optimization,
|
|
name='{}/decoder'.format(name),
|
|
forward_only=False,
|
|
drop_states=False,
|
|
)
|
|
super().__init__(
|
|
encoder_output_dim=encoder_output_dim,
|
|
encoder_outputs=encoder_outputs,
|
|
encoder_lengths=encoder_lengths,
|
|
decoder_cell=decoder_cell,
|
|
decoder_state_dim=decoder_state_dim,
|
|
name=name,
|
|
attention_type=attention_type,
|
|
weighted_encoder_outputs=weighted_encoder_outputs,
|
|
attention_memory_optimization=attention_memory_optimization,
|
|
forward_only=forward_only,
|
|
)
|
|
|
|
|
|
class MILSTMWithAttentionCell(AttentionCell):
|
|
|
|
def __init__(
|
|
self,
|
|
encoder_output_dim,
|
|
encoder_outputs,
|
|
decoder_input_dim,
|
|
decoder_state_dim,
|
|
name,
|
|
attention_type,
|
|
weighted_encoder_outputs,
|
|
forget_bias,
|
|
lstm_memory_optimization,
|
|
attention_memory_optimization,
|
|
forward_only=False,
|
|
):
|
|
decoder_cell = MILSTMCell(
|
|
input_size=decoder_input_dim,
|
|
hidden_size=decoder_state_dim,
|
|
forget_bias=forget_bias,
|
|
memory_optimization=lstm_memory_optimization,
|
|
name='{}/decoder'.format(name),
|
|
forward_only=False,
|
|
drop_states=False,
|
|
)
|
|
super().__init__(
|
|
encoder_output_dim=encoder_output_dim,
|
|
encoder_outputs=encoder_outputs,
|
|
decoder_cell=decoder_cell,
|
|
decoder_state_dim=decoder_state_dim,
|
|
name=name,
|
|
attention_type=attention_type,
|
|
weighted_encoder_outputs=weighted_encoder_outputs,
|
|
attention_memory_optimization=attention_memory_optimization,
|
|
forward_only=forward_only,
|
|
)
|
|
|
|
|
|
def _LSTM(
|
|
cell_class,
|
|
model,
|
|
input_blob,
|
|
seq_lengths,
|
|
initial_states,
|
|
dim_in,
|
|
dim_out,
|
|
scope=None,
|
|
outputs_with_grads=(0,),
|
|
return_params=False,
|
|
memory_optimization=False,
|
|
forget_bias=0.0,
|
|
forward_only=False,
|
|
drop_states=False,
|
|
return_last_layer_only=True,
|
|
static_rnn_unroll_size=None,
|
|
**cell_kwargs
|
|
):
|
|
'''
|
|
Adds a standard LSTM recurrent network operator to a model.
|
|
|
|
cell_class: LSTMCell or compatible subclass
|
|
|
|
model: ModelHelper object new operators would be added to
|
|
|
|
input_blob: the input sequence in a format T x N x D
|
|
where T is sequence size, N - batch size and D - input dimension
|
|
|
|
seq_lengths: blob containing sequence lengths which would be passed to
|
|
LSTMUnit operator
|
|
|
|
initial_states: a list of (2 * num_layers) blobs representing the initial
|
|
hidden and cell states of each layer. If this argument is None,
|
|
these states will be added to the model as network parameters.
|
|
|
|
dim_in: input dimension
|
|
|
|
dim_out: number of units per LSTM layer
|
|
(use int for single-layer LSTM, list of ints for multi-layer)
|
|
|
|
outputs_with_grads : position indices of output blobs for LAST LAYER which
|
|
will receive external error gradient during backpropagation.
|
|
These outputs are: (h_all, h_last, c_all, c_last)
|
|
|
|
return_params: if True, will return a dictionary of parameters of the LSTM
|
|
|
|
memory_optimization: if enabled, the LSTM step is recomputed on backward
|
|
step so that we don't need to store forward activations for each
|
|
timestep. Saves memory with cost of computation.
|
|
|
|
forget_bias: forget gate bias (default 0.0)
|
|
|
|
forward_only: whether to create a backward pass
|
|
|
|
drop_states: drop invalid states, passed through to LSTMUnit operator
|
|
|
|
return_last_layer_only: only return outputs from final layer
|
|
(so that the length of results does not depend on the number of layers)
|
|
|
|
static_rnn_unroll_size: if not None, we will use static RNN which is
|
|
unrolled into Caffe2 graph. The size of the unroll is the value of
|
|
this parameter.
|
|
'''
|
|
if type(dim_out) is not list and type(dim_out) is not tuple:
|
|
dim_out = [dim_out]
|
|
num_layers = len(dim_out)
|
|
|
|
cells = []
|
|
for i in range(num_layers):
|
|
cell = cell_class(
|
|
input_size=(dim_in if i == 0 else dim_out[i - 1]),
|
|
hidden_size=dim_out[i],
|
|
forget_bias=forget_bias,
|
|
memory_optimization=memory_optimization,
|
|
name=scope if num_layers == 1 else None,
|
|
forward_only=forward_only,
|
|
drop_states=drop_states,
|
|
**cell_kwargs
|
|
)
|
|
cells.append(cell)
|
|
|
|
cell = MultiRNNCell(
|
|
cells,
|
|
name=scope,
|
|
forward_only=forward_only,
|
|
) if num_layers > 1 else cells[0]
|
|
|
|
cell = (
|
|
cell if static_rnn_unroll_size is None
|
|
else UnrolledCell(cell, static_rnn_unroll_size))
|
|
|
|
# outputs_with_grads argument indexes into final layer
|
|
outputs_with_grads = [4 * (num_layers - 1) + i for i in outputs_with_grads]
|
|
_, result = cell.apply_over_sequence(
|
|
model=model,
|
|
inputs=input_blob,
|
|
seq_lengths=seq_lengths,
|
|
initial_states=initial_states,
|
|
outputs_with_grads=outputs_with_grads,
|
|
)
|
|
|
|
if return_last_layer_only:
|
|
result = result[4 * (num_layers - 1):]
|
|
if return_params:
|
|
result = list(result) + [{
|
|
'input': cell.get_input_params(),
|
|
'recurrent': cell.get_recurrent_params(),
|
|
}]
|
|
return tuple(result)
|
|
|
|
|
|
LSTM = functools.partial(_LSTM, LSTMCell)
|
|
BasicRNN = functools.partial(_LSTM, BasicRNNCell)
|
|
MILSTM = functools.partial(_LSTM, MILSTMCell)
|
|
LayerNormLSTM = functools.partial(_LSTM, LayerNormLSTMCell)
|
|
LayerNormMILSTM = functools.partial(_LSTM, LayerNormMILSTMCell)
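# A minimal usage sketch of the LSTM helper above (hypothetical blob names and
# sizes, for illustration only):
#
#     model = ModelHelper(name='example')
#     hidden_all, hidden_last, cell_all, cell_last = LSTM(
#         model, 'input_sequence', 'seq_lengths', initial_states=None,
#         dim_in=64, dim_out=128, scope='lstm1',
#     )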
|
|
|
|
|
|
class UnrolledCell(RNNCell):
|
|
def __init__(self, cell, T):
|
|
self.T = T
|
|
self.cell = cell
|
|
|
|
def apply_over_sequence(
|
|
self,
|
|
model,
|
|
inputs,
|
|
seq_lengths,
|
|
initial_states,
|
|
outputs_with_grads=None,
|
|
):
|
|
inputs = self.cell.prepare_input(model, inputs)
|
|
|
|
# Now they are blob references - outputs of splitting the input sequence
|
|
split_inputs = model.net.Split(
|
|
inputs,
|
|
[str(inputs) + "_timestep_{}".format(i)
|
|
for i in range(self.T)],
|
|
axis=0)
|
|
if self.T == 1:
|
|
split_inputs = [split_inputs]
|
|
|
|
states = initial_states
|
|
all_states = []
|
|
for t in range(0, self.T):
|
|
scope_name = "timestep_{}".format(t)
|
|
# Parameters of all timesteps are shared
|
|
with ParameterSharing({scope_name: ''}),\
|
|
scope.NameScope(scope_name):
|
|
timestep = model.param_init_net.ConstantFill(
|
|
[], "timestep", value=t, shape=[1],
|
|
dtype=core.DataType.INT32,
|
|
device_option=core.DeviceOption(caffe2_pb2.CPU))
|
|
states = self.cell._apply(
|
|
model=model,
|
|
input_t=split_inputs[t],
|
|
seq_lengths=seq_lengths,
|
|
states=states,
|
|
timestep=timestep,
|
|
)
|
|
all_states.append(states)
|
|
|
|
all_states = zip(*all_states)
|
|
all_states = [
|
|
model.net.Concat(
|
|
list(full_output),
|
|
[
|
|
str(full_output[0])[len("timestep_0/"):] + "_concat",
|
|
str(full_output[0])[len("timestep_0/"):] + "_concat_info"
|
|
|
|
],
|
|
axis=0)[0]
|
|
for full_output in all_states
|
|
]
|
|
# Interleave the state values similar to
|
|
#
|
|
# x = [1, 3, 5]
|
|
# y = [2, 4, 6]
|
|
# z = [val for pair in zip(x, y) for val in pair]
|
|
# # z is [1, 2, 3, 4, 5, 6]
|
|
#
|
|
# and returns it as outputs
|
|
outputs = tuple(
|
|
state for state_pair in zip(all_states, states) for state in state_pair
|
|
)
|
|
outputs_without_grad = set(range(len(outputs))) - set(
|
|
outputs_with_grads)
|
|
for i in outputs_without_grad:
|
|
model.net.ZeroGradient(outputs[i], [])
|
|
logging.debug("Added 0 gradients for blobs:",
|
|
[outputs[i] for i in outputs_without_grad])
|
|
|
|
final_output = self.cell._prepare_output_sequence(model, outputs)
|
|
|
|
return final_output, outputs
|
|
|
|
|
|
def GetLSTMParamNames():
|
|
weight_params = ["input_gate_w", "forget_gate_w", "output_gate_w", "cell_w"]
|
|
bias_params = ["input_gate_b", "forget_gate_b", "output_gate_b", "cell_b"]
|
|
return {'weights': weight_params, 'biases': bias_params}
|
|
|
|
|
|
def InitFromLSTMParams(lstm_pblobs, param_values):
|
|
'''
|
|
Set the parameters of LSTM based on predefined values
|
|
'''
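# A sketch of the expected param_values layout (hypothetical arrays; the
# per-gate names come from GetLSTMParamNames()):
#
#     param_values = {
#         'input': {'input_gate_w': w_i, 'forget_gate_w': w_f, 'output_gate_w': w_o,
#                   'cell_w': w_c, 'input_gate_b': b_i, 'forget_gate_b': b_f,
#                   'output_gate_b': b_o, 'cell_b': b_c},
#         'recurrent': {...},  # same keys as above
#     }
#
# lstm_pblobs is the parameter-blob dict returned by LSTM(..., return_params=True).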
|
|
weight_params = GetLSTMParamNames()['weights']
|
|
bias_params = GetLSTMParamNames()['biases']
|
|
for input_type in param_values.keys():
|
|
weight_values = [
|
|
param_values[input_type][w].flatten()
|
|
for w in weight_params
|
|
]
|
|
wmat = np.array([])
|
|
for w in weight_values:
|
|
wmat = np.append(wmat, w)
|
|
bias_values = [
|
|
param_values[input_type][b].flatten()
|
|
for b in bias_params
|
|
]
|
|
bm = np.array([])
|
|
for b in bias_values:
|
|
bm = np.append(bm, b)
|
|
|
|
weights_blob = lstm_pblobs[input_type]['weights']
|
|
bias_blob = lstm_pblobs[input_type]['biases']
|
|
cur_weight = workspace.FetchBlob(weights_blob)
|
|
cur_biases = workspace.FetchBlob(bias_blob)
|
|
|
|
workspace.FeedBlob(
|
|
weights_blob,
|
|
wmat.reshape(cur_weight.shape).astype(np.float32))
|
|
workspace.FeedBlob(
|
|
bias_blob,
|
|
bm.reshape(cur_biases.shape).astype(np.float32))
|
|
|
|
|
|
def cudnn_LSTM(model, input_blob, initial_states, dim_in, dim_out,
|
|
scope, recurrent_params=None, input_params=None,
|
|
num_layers=1, return_params=False):
|
|
'''
|
|
CuDNN version of LSTM for GPUs.
|
|
input_blob Blob containing the input. Will need to be available
|
|
when param_init_net is run, because the sequence lengths
|
|
and batch sizes will be inferred from the size of this
|
|
blob.
|
|
initial_states tuple of (hidden_init, cell_init) blobs
|
|
dim_in input dimensions
|
|
dim_out output/hidden dimension
|
|
scope namescope to apply
|
|
recurrent_params dict of blobs containing values for recurrent
|
|
gate weights, biases (if None, use random init values)
|
|
See GetLSTMParamNames() for format.
|
|
input_params dict of blobs containing values for input
|
|
gate weights, biases (if None, use random init values)
|
|
See GetLSTMParamNames() for format.
|
|
num_layers number of LSTM layers
|
|
return_params if True, returns (param_extract_net, param_mapping)
|
|
where param_extract_net is a net that when run, will
|
|
populate the blobs specified in param_mapping with the
|
|
current gate weights and biases (input/recurrent).
|
|
Useful for assigning the values back to non-cuDNN
|
|
LSTM.
|
|
'''
|
|
with core.NameScope(scope):
|
|
weight_params = GetLSTMParamNames()['weights']
|
|
bias_params = GetLSTMParamNames()['biases']
|
|
|
|
input_weight_size = dim_out * dim_in
|
|
upper_layer_input_weight_size = dim_out * dim_out
|
|
recurrent_weight_size = dim_out * dim_out
|
|
input_bias_size = dim_out
|
|
recurrent_bias_size = dim_out
|
|
|
|
def init(layer, pname, input_type):
|
|
input_weight_size_for_layer = input_weight_size if layer == 0 else \
|
|
upper_layer_input_weight_size
|
|
if pname in weight_params:
|
|
sz = input_weight_size_for_layer if input_type == 'input' \
|
|
else recurrent_weight_size
|
|
elif pname in bias_params:
|
|
sz = input_bias_size if input_type == 'input' \
|
|
else recurrent_bias_size
|
|
else:
|
|
assert False, "unknown parameter type {}".format(pname)
|
|
return model.param_init_net.UniformFill(
|
|
[],
|
|
"lstm_init_{}_{}_{}".format(input_type, pname, layer),
|
|
shape=[sz])
|
|
|
|
# Multiply by 4 since we have 4 gates per LSTM unit
|
|
first_layer_sz = input_weight_size + recurrent_weight_size + \
|
|
input_bias_size + recurrent_bias_size
|
|
upper_layer_sz = upper_layer_input_weight_size + \
|
|
recurrent_weight_size + input_bias_size + \
|
|
recurrent_bias_size
|
|
total_sz = 4 * (first_layer_sz + (num_layers - 1) * upper_layer_sz)
|
|
|
|
weights = model.create_param(
|
|
'lstm_weight',
|
|
shape=[total_sz],
|
|
initializer=Initializer('UniformFill'),
|
|
tags=ParameterTags.WEIGHT,
|
|
)
|
|
|
|
lstm_args = {
|
|
'hidden_size': dim_out,
|
|
'rnn_mode': 'lstm',
|
|
'bidirectional': 0, # TODO
|
|
'dropout': 1.0, # TODO
|
|
'input_mode': 'linear', # TODO
|
|
'num_layers': num_layers,
|
|
'engine': 'CUDNN'
|
|
}
|
|
|
|
param_extract_net = core.Net("lstm_param_extractor")
|
|
param_extract_net.AddExternalInputs([input_blob, weights])
|
|
param_extract_mapping = {}
|
|
|
|
# Populate the weights-blob from blobs containing parameters for
|
|
# the individual components of the LSTM, such as forget/input gate
|
|
# weights and biases. Also, create a special param_extract_net that
|
|
# can be used to grab those individual params from the black-box
|
|
# weights blob. These results can be then fed to InitFromLSTMParams()
|
|
for input_type in ['input', 'recurrent']:
|
|
param_extract_mapping[input_type] = {}
|
|
p = recurrent_params if input_type == 'recurrent' else input_params
|
|
if p is None:
|
|
p = {}
|
|
for pname in weight_params + bias_params:
|
|
for j in range(0, num_layers):
|
|
values = p[pname] if pname in p else init(j, pname, input_type)
|
|
model.param_init_net.RecurrentParamSet(
|
|
[input_blob, weights, values],
|
|
weights,
|
|
layer=j,
|
|
input_type=input_type,
|
|
param_type=pname,
|
|
**lstm_args
|
|
)
|
|
if pname not in param_extract_mapping[input_type]:
|
|
param_extract_mapping[input_type][pname] = {}
|
|
b = param_extract_net.RecurrentParamGet(
|
|
[input_blob, weights],
|
|
["lstm_{}_{}_{}".format(input_type, pname, j)],
|
|
layer=j,
|
|
input_type=input_type,
|
|
param_type=pname,
|
|
**lstm_args
|
|
)
|
|
param_extract_mapping[input_type][pname][j] = b
|
|
|
|
(hidden_input_blob, cell_input_blob) = initial_states
|
|
output, hidden_output, cell_output, rnn_scratch, dropout_states = \
|
|
model.net.Recurrent(
|
|
[input_blob, hidden_input_blob, cell_input_blob, weights],
|
|
["lstm_output", "lstm_hidden_output", "lstm_cell_output",
|
|
"lstm_rnn_scratch", "lstm_dropout_states"],
|
|
seed=random.randint(0, 100000), # TODO: dropout seed
|
|
**lstm_args
|
|
)
|
|
model.net.AddExternalOutputs(
|
|
hidden_output, cell_output, rnn_scratch, dropout_states)
|
|
|
|
if return_params:
|
|
param_extract = param_extract_net, param_extract_mapping
|
|
return output, hidden_output, cell_output, param_extract
|
|
else:
|
|
return output, hidden_output, cell_output
|
|
|
|
|
|
def LSTMWithAttention(
|
|
model,
|
|
decoder_inputs,
|
|
decoder_input_lengths,
|
|
initial_decoder_hidden_state,
|
|
initial_decoder_cell_state,
|
|
initial_attention_weighted_encoder_context,
|
|
encoder_output_dim,
|
|
encoder_outputs,
|
|
encoder_lengths,
|
|
decoder_input_dim,
|
|
decoder_state_dim,
|
|
scope,
|
|
attention_type=AttentionType.Regular,
|
|
outputs_with_grads=(0, 4),
|
|
weighted_encoder_outputs=None,
|
|
lstm_memory_optimization=False,
|
|
attention_memory_optimization=False,
|
|
forget_bias=0.0,
|
|
forward_only=False,
|
|
):
|
|
'''
|
|
Adds a LSTM with attention mechanism to a model.
|
|
|
|
The implementation is based on https://arxiv.org/abs/1409.0473, with
|
|
a small difference in the order
|
|
in which we compute the new attention context and the new hidden state, similarly to
|
|
https://arxiv.org/abs/1508.04025.
|
|
|
|
The model uses encoder-decoder naming conventions,
|
|
where the decoder is the sequence the op is iterating over,
|
|
while computing the attention context over the encoder.
|
|
|
|
model: ModelHelper object new operators would be added to
|
|
|
|
decoder_inputs: the input sequence in a format T x N x D
|
|
where T is sequence size, N - batch size and D - input dimension
|
|
|
|
decoder_input_lengths: blob containing sequence lengths
|
|
which would be passed to LSTMUnit operator
|
|
|
|
initial_decoder_hidden_state: initial hidden state of LSTM
|
|
|
|
initial_decoder_cell_state: initial cell state of LSTM
|
|
|
|
initial_attention_weighted_encoder_context: initial attention context
|
|
|
|
encoder_output_dim: dimension of encoder outputs
|
|
|
|
encoder_outputs: the sequence, on which we compute the attention context
|
|
at every iteration
|
|
|
|
encoder_lengths: a tensor with lengths of each encoder sequence in batch
|
|
(may be None, meaning all encoder sequences are of same length)
|
|
|
|
decoder_input_dim: input dimension (last dimension on decoder_inputs)
|
|
|
|
decoder_state_dim: size of hidden states of LSTM
|
|
|
|
attention_type: One of: AttentionType.Regular, AttentionType.Recurrent.
|
|
Determines which type of attention mechanism to use.
|
|
|
|
outputs_with_grads : position indices of output blobs which will receive
|
|
external error gradient during backpropagation
|
|
|
|
weighted_encoder_outputs: encoder outputs to be used to compute attention
|
|
weights. In the basic case it's just linear transformation of
|
|
encoder outputs (that is the default, when weighted_encoder_outputs is None).
|
|
However, it can be something more complicated - like a separate
|
|
encoder network (for example, in case of convolutional encoder)
|
|
|
|
lstm_memory_optimization: recompute LSTM activations on backward pass, so
|
|
we don't need to store their values in forward passes
|
|
|
|
attention_memory_optimization: recompute attention for backward pass
|
|
|
|
forward_only: whether to create only forward pass
|
|
'''
|
|
cell = LSTMWithAttentionCell(
|
|
encoder_output_dim=encoder_output_dim,
|
|
encoder_outputs=encoder_outputs,
|
|
encoder_lengths=encoder_lengths,
|
|
decoder_input_dim=decoder_input_dim,
|
|
decoder_state_dim=decoder_state_dim,
|
|
name=scope,
|
|
attention_type=attention_type,
|
|
weighted_encoder_outputs=weighted_encoder_outputs,
|
|
forget_bias=forget_bias,
|
|
lstm_memory_optimization=lstm_memory_optimization,
|
|
attention_memory_optimization=attention_memory_optimization,
|
|
forward_only=forward_only,
|
|
)
|
|
initial_states = [
|
|
initial_decoder_hidden_state,
|
|
initial_decoder_cell_state,
|
|
initial_attention_weighted_encoder_context,
|
|
]
|
|
if attention_type == AttentionType.SoftCoverage:
|
|
initial_states.append(cell.build_initial_coverage(model))
|
|
_, result = cell.apply_over_sequence(
|
|
model=model,
|
|
inputs=decoder_inputs,
|
|
seq_lengths=decoder_input_lengths,
|
|
initial_states=initial_states,
|
|
outputs_with_grads=outputs_with_grads,
|
|
)
|
|
return result
|
|
|
|
|
|
def _layered_LSTM(
|
|
model, input_blob, seq_lengths, initial_states,
|
|
dim_in, dim_out, scope, outputs_with_grads=(0,), return_params=False,
|
|
memory_optimization=False, forget_bias=0.0, forward_only=False,
|
|
drop_states=False, create_lstm=None):
|
|
params = locals() # leave it as the first line to grab all params
|
|
params.pop('create_lstm')
|
|
if not isinstance(dim_out, list):
|
|
return create_lstm(**params)
|
|
elif len(dim_out) == 1:
|
|
params['dim_out'] = dim_out[0]
|
|
return create_lstm(**params)
|
|
|
|
assert len(dim_out) != 0, "dim_out list can't be empty"
|
|
assert return_params is False, "return_params not supported for layering"
|
|
for i, output_dim in enumerate(dim_out):
|
|
params.update({
|
|
'dim_out': output_dim
|
|
})
|
|
output, last_output, all_states, last_state = create_lstm(**params)
|
|
params.update({
|
|
'input_blob': output,
|
|
'dim_in': output_dim,
|
|
'initial_states': (last_output, last_state),
|
|
'scope': scope + '_layer_{}'.format(i + 1)
|
|
})
|
|
return output, last_output, all_states, last_state
|
|
|
|
|
|
layered_LSTM = functools.partial(_layered_LSTM, create_lstm=LSTM)
|