Refactor rnn export (#7263)
* rnn refactor: extract rnn weights and biases
* rnn refactor: make rnn with converted outputs
* rnn refactor: finish it off
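For orientation, a minimal sketch (not part of this commit) of the path these converters sit on: export a PyTorch LSTM to ONNX, then let the Caffe2 ONNX backend prepare it, at which point the LSTM node is routed through the _special_operators table changed below. Module paths, shapes, and the file name are illustrative.

import torch
import onnx
import caffe2.python.onnx.backend as c2

lstm = torch.nn.LSTM(input_size=10, hidden_size=20, bidirectional=True)
x = torch.randn(5, 3, 10)       # (seq_len, batch, input_size)
h0 = torch.randn(2, 3, 20)      # (num_directions, batch, hidden_size)
c0 = torch.randn(2, 3, 20)
torch.onnx.export(lstm, (x, (h0, c0)), 'lstm.onnx')

rep = c2.prepare(onnx.load('lstm.onnx'))   # the ONNX LSTM node is converted to Caffe2 ops here
outputs = rep.run((x.numpy(), h0.numpy(), c0.numpy()))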
@@ -188,9 +188,9 @@ class Caffe2Backend(Backend):
     # the value is an attribute of this class that is a
     # function from ToffeIR node_def to caffe2 op_def
     _special_operators = {
-        'LSTM': '_create_lstm',
-        'GRU': '_create_gru',
-        'RNN': '_create_rnn',
+        'LSTM': '_create_rnn_variant',
+        'GRU': '_create_rnn_variant',
+        'RNN': '_create_rnn_variant',
         'Loop': '_create_loop',
         'If': '_create_if',
     }
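The mapping above now routes all three recurrent ops to the shared _create_rnn_variant converter. The backend resolves the registered method name on the class and calls it; a self-contained toy of that dispatch pattern (the names here are illustrative, not the backend's real plumbing):

class ToyBackend(object):
    _special_operators = {'LSTM': '_create_rnn_variant',
                          'GRU': '_create_rnn_variant',
                          'RNN': '_create_rnn_variant'}

    @classmethod
    def _create_rnn_variant(cls, node):
        return 'converted %s' % node

    @classmethod
    def convert(cls, op_type, node):
        # look up the handler name registered for this op type and call it
        return getattr(cls, cls._special_operators[op_type])(node)

print(ToyBackend.convert('GRU', 'gru_node'))   # converted gru_node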
@@ -295,386 +295,204 @@ class Caffe2Backend(Backend):
        return c2_op

    @classmethod
    def _rnn_shape_inference(cls, init_model, pred_model, n, input_blob, W):
        for x in itertools.chain(init_model.graph.input,
                                 init_model.graph.value_info,
                                 pred_model.graph.input,
                                 pred_model.graph.value_info):
            if x.name == W:
                return x.type.tensor_type.shape.dim[1].dim_value
    def _rnn_reform_weights(cls, reforms, name, hidden_size, init_net, gates, reorder_indices):
        for name_from, name_to, do_concat, extra_dims in reforms:
            gate_blobs = ['%s/%s_%s' % (name, prefix, name_to) for prefix in gates]
            for i, x in enumerate(gate_blobs):
                dim0 = i * hidden_size, (i+1) * hidden_size
                starts, ends = zip(dim0, *extra_dims)
                init_net.Slice(name_from, x, starts=starts, ends=ends)
            if do_concat:
                reordered_gate_blobs = [gate_blobs[i] for i in reorder_indices]
                init_net.Concat(reordered_gate_blobs, ['%s/%s' % (name, name_to), cls.dummy_name()], axis=0)

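A small numpy sketch (not from this diff) of what _rnn_reform_weights expresses with Slice and Concat: cut the stacked per-gate blocks out of a parameter blob along dim 0, then re-concatenate them in Caffe2's gate order; shown here for the GRU case, where ONNX's z, r, h order becomes r, z, h via reorder_indices = [1, 0, 2].

import numpy as np

hidden_size, input_size = 4, 5
gates = ['update', 'reset', 'output']                 # ONNX GRU gate order: z, r, h
reorder_indices = [1, 0, 2]                           # Caffe2 GRU gate order: r, z, h

W = np.arange(len(gates) * hidden_size * input_size).reshape(len(gates) * hidden_size, input_size)
gate_blobs = [W[i * hidden_size:(i + 1) * hidden_size] for i in range(len(gates))]
W_reordered = np.concatenate([gate_blobs[i] for i in reorder_indices], axis=0)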
    @classmethod
    def _create_rnn(cls, init_model, pred_model, n, opset_version):
    def _make_rnn_direction(cls, input_blob, B, W, R, initial_states_and_names, sequence_lens,
                            pred_mh, init_net,
                            input_size, hidden_size, num_gates, direction_offset,
                            Bi, Br, W_, R_,
                            reform, make_cell, keep_outputs):
        name = cls.dummy_name()

        # input and recurrence biases are squashed together in onnx
        # but not in caffe2
        gates_hidden_size = num_gates * hidden_size
        bias_offset = 2 * direction_offset * gates_hidden_size
        weight_offset = direction_offset * gates_hidden_size
        Bi = init_net.Slice(B, name + Bi,
                            starts=[bias_offset + 0 * gates_hidden_size],
                            ends =[bias_offset + 1 * gates_hidden_size])
        Br = init_net.Slice(B, name + Br,
                            starts=[bias_offset + 1 * gates_hidden_size],
                            ends =[bias_offset + 2 * gates_hidden_size])
        W_ = init_net.Slice(W, name + W_,
                            starts=[weight_offset + 0 * gates_hidden_size, 0],
                            ends =[weight_offset + 1 * gates_hidden_size,-1])
        R_ = init_net.Slice(R, name + R_,
                            starts=[weight_offset + 0 * gates_hidden_size, 0],
                            ends =[weight_offset + 1 * gates_hidden_size,-1])

        initial_states_sliced = []
        for initial_state, name_suffix in initial_states_and_names:
            initial_states_sliced.append(
                init_net.Slice(initial_state, name + name_suffix,
                               starts=[direction_offset + 0, 0, 0],
                               ends =[direction_offset + 1,-1,-1]))

        if direction_offset == 1:
            if sequence_lens is not None:
                seq_lens_for_reverse = sequence_lens
            else:
                input_shape = pred_mh.net.Shape(input_blob, name + '/input_shape')
                batch_size = pred_mh.net.Slice(input_shape, name + '/batch_size_slice', starts=[1], ends=[2])
                seq_len = pred_mh.net.Slice(input_shape, name + '/seq_len_slice', starts=[0], ends=[1])
                dummy_sequence_lens = pred_mh.net.Tile([seq_len, batch_size], name + '/dummy_sequence_lens', axis=0)
                pred_mh.net.Reshape(dummy_sequence_lens, [dummy_sequence_lens, cls.dummy_name()], shape=[-1])
                seq_lens_for_reverse = dummy_sequence_lens

        reform(Bi, Br, W_, R_, name, hidden_size, init_net)

        if direction_offset == 1:
            input = pred_mh.net.ReversePackedSegs(
                [input_blob, seq_lens_for_reverse], name + "/input-reversed")
        else:
            input = input_blob

        outputs = keep_outputs(list(make_cell(
            pred_mh,
            input,
            sequence_lens,
            initial_states_sliced,
            input_size,
            hidden_size,
            name,
            drop_states=False,
            forward_only=True,
        )))

        if direction_offset == 1:
            outputs[0] = pred_mh.net.ReversePackedSegs(
                [outputs[0], seq_lens_for_reverse], name + "/output-reversed")

        return outputs

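A numpy sketch (not from this diff) of the offset arithmetic above: the converter addresses the ONNX bias blob B as a flat vector holding the input bias Wb followed by the recurrence bias Rb for each direction in turn, and W/R as per-direction stacks of gate matrices along dim 0. gates_hidden_size plays the role of G below; sizes are illustrative.

import numpy as np

num_gates, hidden_size, input_size, num_directions = 3, 4, 5, 2
G = num_gates * hidden_size                           # gates_hidden_size

B = np.zeros(num_directions * 2 * G)                  # [Wb_fwd, Rb_fwd, Wb_bwd, Rb_bwd]
W = np.zeros((num_directions * G, input_size))        # per-direction gate weights stacked on dim 0

for direction_offset in range(num_directions):
    bias_offset = 2 * direction_offset * G
    Bi = B[bias_offset + 0 * G: bias_offset + 1 * G]  # input-to-hidden bias for this direction
    Br = B[bias_offset + 1 * G: bias_offset + 2 * G]  # recurrent bias for this direction

    weight_offset = direction_offset * G
    W_dir = W[weight_offset: weight_offset + G, :]    # this direction's stacked gate weights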
    @classmethod
    def _create_rnn_variant(cls, init_model, pred_model, n, opset_version):
        assert init_model is not None, "cannot convert RNNs without access to the full model"
        assert pred_model is not None, "cannot convert RNNs without access to the full model"

        attrs = dict(n.attrs) # make a copy, which is safe to mutate
        hidden_size = attrs.pop('hidden_size')
        activation = force_unicode(attrs.pop('activations', ('tanh',))[0])
        direction = force_unicode(attrs.pop('direction', 'forward'))

        if n.op_type == 'RNN':
            activation = force_unicode(attrs.pop('activations', ('tanh',))[0])
        elif n.op_type == 'GRU':
            linear_before_reset = attrs.pop('linear_before_reset', 0)

        assert not attrs, "unsupported RNN attributes: " + str(attrs.keys())
        assert direction in ['forward', 'bidirectional'], "unsupported backwards RNN"
        assert direction in ['forward', 'bidirectional'], "unsupported backwards RNN/GRU/LSTM"

        input_blob, W, R, B, sequence_lens, initial_h = n.inputs
        if n.op_type in ['RNN', 'GRU']:
            input_blob, W, R, B, sequence_lens, initial_h = n.inputs
        elif n.op_type == 'LSTM':
            input_blob, W, R, B, sequence_lens, initial_h, initial_c = n.inputs

        if sequence_lens == "":
            sequence_lens = None

        input_size = cls._rnn_shape_inference(init_model, pred_model, n, input_blob, W)
        if input_size is None:
            raise RuntimeError("best-effort shape inference for RNN input failed")
        for x in itertools.chain(init_model.graph.input,
                                 init_model.graph.value_info,
                                 pred_model.graph.input,
                                 pred_model.graph.value_info):
            if x.name == W:
                input_size = x.type.tensor_type.shape.dim[1].dim_value
                break
        else:
            raise RuntimeError("best-effort shape inference for RNN/GRU/LSTM failed")

        init_net = core.Net("init-net")
        pred_mh = ModelHelper()

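The op-type branch above unpacks a different input list and pops different attributes per variant. A sketch (not from this diff) of the three ONNX node signatures as the converter expects them, built with onnx.helper; attribute values are illustrative.

import onnx.helper as oh

rnn = oh.make_node('RNN', ['X', 'W', 'R', 'B', 'seq_lens', 'h0'], ['Y', 'Y_h'],
                   hidden_size=4, activations=['Tanh'])
gru = oh.make_node('GRU', ['X', 'W', 'R', 'B', 'seq_lens', 'h0'], ['Y', 'Y_h'],
                   hidden_size=4, linear_before_reset=1)
lstm = oh.make_node('LSTM', ['X', 'W', 'R', 'B', 'seq_lens', 'h0', 'c0'], ['Y', 'Y_h', 'Y_c'],
                    hidden_size=4, direction='bidirectional')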
        def make_rnn(direction_offset):
            name = cls.dummy_name()
        if n.op_type == 'RNN':
            def reform(*args):
                pass

            # input and recurrence biases are squashed together in
            # onnx but not in caffe2
            def make_cell(*args, **kwargs):
                return rnn_cell.BasicRNN(*args, activation=activation, **kwargs)

            bias_offset = 2 * direction_offset * hidden_size
            init_net.Slice(B, name + "/i2h_b",
                           starts=[bias_offset + 0 * hidden_size],
                           ends =[bias_offset + 1 * hidden_size])
            init_net.Slice(B, name + "/gates_t_b",
                           starts=[bias_offset + 1 * hidden_size],
                           ends =[bias_offset + 2 * hidden_size])
            def make_rnn(direction_offset):
                return cls._make_rnn_direction(
                    input_blob, B, W, R, [(initial_h, '/initial_h')], sequence_lens,
                    pred_mh, init_net, input_size, hidden_size, 1, direction_offset,
                    "/i2h_b", "/gates_t_b", "/i2h_w", "/gates_t_w",
                    reform, make_cell, lambda x: x)

            weight_offset = direction_offset * hidden_size
            init_net.Slice(W, name + '/i2h_w',
                           starts=[weight_offset + 0 * hidden_size, 0],
                           ends =[weight_offset + 1 * hidden_size,-1])
            init_net.Slice(R, name + '/gates_t_w',
                           starts=[weight_offset + 0 * hidden_size, 0],
                           ends =[weight_offset + 1 * hidden_size,-1])
        elif n.op_type == 'GRU':
            def reform(Bi, Br, W_, R_, name, hidden_size, init_net):
                # caffe2 has a different order from onnx. We need to rearrange
                # z r h -> r z h
                reforms = ((W_, 'i2h_w', True, [(0,-1)]),
                           (R_, 'gate_t_w', False, [(0,-1)]),
                           (Bi, 'i2h_b', True, []),
                           (Br, 'gate_t_b', False, []))
                cls._rnn_reform_weights(reforms, name, hidden_size, init_net,
                                        ['update', 'reset', 'output'], [1, 0, 2])

            initial_h_sliced = name + '/initial_h'
            init_net.Slice(initial_h, initial_h_sliced,
                           starts=[direction_offset + 0, 0, 0],
                           ends =[direction_offset + 1,-1,-1])
            def make_cell(*args, **kwargs):
                return gru_cell.GRU(*args, linear_before_reset=linear_before_reset, **kwargs)

            def make_rnn(direction_offset):
                return cls._make_rnn_direction(
                    input_blob, B, W, R, [(initial_h, '/initial_h')], sequence_lens,
                    pred_mh, init_net, input_size, hidden_size, 3, direction_offset,
                    "_bias_i2h", "_bias_gates", "/i2h_w_pre", "/gates_t_w_pre",
                    reform, make_cell, lambda x: x)

            if direction_offset == 1:
                if sequence_lens is not None:
                    seq_lens_for_reverse = sequence_lens
                else:
                    input_shape = pred_mh.net.Shape(input_blob, name + '/input_shape')
                    batch_size = pred_mh.net.Slice(input_shape, name + '/batch_size_slice', starts=[1], ends=[2])
                    seq_len = pred_mh.net.Slice(input_shape, name + '/seq_len_slice', starts=[0], ends=[1])
                    dummy_sequence_lens = pred_mh.net.Tile([seq_len, batch_size], name + '/dummy_sequence_lens', axis=0)
                    pred_mh.net.Reshape(dummy_sequence_lens, [dummy_sequence_lens, cls.dummy_name()], shape=[-1])
                    seq_lens_for_reverse = dummy_sequence_lens
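The Shape/Slice/Tile/Reshape sequence above synthesizes a sequence_lens blob when the ONNX model does not provide one, so that ReversePackedSegs can still reverse the batch. A numpy sketch (not from this diff) of the values it produces, assuming every example spans the full sequence:

import numpy as np

X = np.zeros((5, 3, 10))                              # (seq_len, batch, input_size)
input_shape = np.array(X.shape)                       # Shape
seq_len = input_shape[0:1]                            # Slice starts=[0], ends=[1]
batch_size = input_shape[1:2]                         # Slice starts=[1], ends=[2]
dummy_sequence_lens = np.tile(seq_len, batch_size).reshape(-1)   # Tile + Reshape -> [5, 5, 5]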
        elif n.op_type == 'LSTM':
            def reform(Bi, Br, W_, R_, name, hidden_size, init_net):
                # caffe2 has a different order from onnx. We need to rearrange
                # i o f c -> i f o c
                reforms = ((W_, 'i2h_w', True, [(0, -1)]),
                           (R_, 'gates_t_w', True, [(0, -1)]),
                           (Bi, 'i2h_b' , True, []),
                           (Br, 'gates_t_b', True, []))
                cls._rnn_reform_weights(reforms, name, hidden_size, init_net,
                                        ['input', 'output', 'forget', 'cell'], [0, 2, 1, 3])

            if direction_offset == 1:
                input = pred_mh.net.ReversePackedSegs(
                    [input_blob, seq_lens_for_reverse], name + "/input-reversed")
            else:
                input = input_blob
            def make_cell(*args, **kwargs):
                return rnn_cell.LSTM(*args, **kwargs)

            hidden_t_all, hidden_t_last = rnn_cell.BasicRNN(
                pred_mh,
                input,
                sequence_lens,
                [initial_h_sliced],
                input_size,
                hidden_size,
                name,
                drop_states=False,
                forward_only=True,
                activation=activation
            )

            if direction_offset == 1:
                hidden_t_all = pred_mh.net.ReversePackedSegs(
                    [hidden_t_all, seq_lens_for_reverse], name + "/output-reversed")

            return hidden_t_all, hidden_t_last
            def make_rnn(direction_offset):
                return cls._make_rnn_direction(
                    input_blob, B, W, R, [(initial_h, '/initial_h'), (initial_c, '/initial_c')], sequence_lens,
                    pred_mh, init_net, input_size, hidden_size, 4, direction_offset,
                    "/i2h_b", "/gates_t_b", "/i2h_w", "/gates_t_w",
                    reform, make_cell, lambda x: [x[0], x[1], x[3]])

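A numpy sketch (not from this diff) of the LSTM gate reorder named in the comment above: ONNX stacks the gates as input, output, forget, cell, while the Caffe2 cell expects input, forget, output, cell, hence reorder_indices = [0, 2, 1, 3].

import numpy as np

hidden_size, input_size = 4, 5
W_onnx = np.arange(4 * hidden_size * input_size).reshape(4 * hidden_size, input_size)   # [i; o; f; c]
gate_blobs = [W_onnx[i * hidden_size:(i + 1) * hidden_size] for i in range(4)]
W_caffe2 = np.concatenate([gate_blobs[i] for i in [0, 2, 1, 3]], axis=0)                # [i; f; o; c]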
        if direction == 'forward':
            hidden_t_all, hidden_t_last = make_rnn(0)
            outputs = make_rnn(0)

            # in the forward case, storage is shared between the two
            # outputs. We need to decouple them so that the
            # VariableLengthSequencePadding only mutates n.outputs[0]
            pred_mh.net.Copy(hidden_t_last, n.outputs[1])
            # in the forward case, storage is shared between the
            # last outputs. We need to decouple them so that the
            # VariableLengthSequencePadding only mutates
            # n.outputs[0]
            for i in range(1, len(outputs)):
                pred_mh.net.Copy(outputs[i], n.outputs[i])

            pred_mh.net = pred_mh.net.Clone(
                "dummy-clone-net",
                blob_remap={ hidden_t_all: n.outputs[0] }
                "dummy-clone-net", blob_remap={ outputs[0]: n.outputs[0] }
            )
        elif direction == 'bidirectional':
            hidden_t_all_f, hidden_t_last_f = make_rnn(0)
            hidden_t_all_b, hidden_t_last_b = make_rnn(1)
            pred_mh.net.Concat([hidden_t_all_f, hidden_t_all_b],
            outputs_f = make_rnn(0)
            outputs_b = make_rnn(1)

            pred_mh.net.Concat([outputs_f[0], outputs_b[0]],
                               [n.outputs[0], cls.dummy_name()], axis=2)
            pred_mh.net.Concat([hidden_t_last_f, hidden_t_last_b],
                               [n.outputs[1], cls.dummy_name()], axis=0)

        if sequence_lens is not None:
            pred_mh.net.VariableLengthSequencePadding(
                [n.outputs[0], sequence_lens], [n.outputs[0]])

        return Caffe2Ops(list(pred_mh.Proto().op),
                         list(init_net.Proto().op),
                         list(pred_mh.Proto().external_input))

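A numpy sketch (not from this diff) of why the two Concat calls above use different axes in the bidirectional case, assuming the per-direction cell outputs are laid out as (seq_len, batch, hidden) for the full sequence and (1, batch, hidden) for the final state.

import numpy as np

seq_len, batch, hidden_size = 5, 3, 4
Y_f = np.zeros((seq_len, batch, hidden_size))
Y_b = np.ones((seq_len, batch, hidden_size))
Yh_f = np.zeros((1, batch, hidden_size))
Yh_b = np.ones((1, batch, hidden_size))

Y = np.concatenate([Y_f, Y_b], axis=2)      # (5, 3, 8): both directions side by side per step
Y_h = np.concatenate([Yh_f, Yh_b], axis=0)  # (2, 3, 4): one slot per direction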
    @classmethod
    def _create_lstm(cls, init_model, pred_model, n, opset_version):
        assert init_model is not None, "cannot convert LSTMs without access to the full model"
        assert pred_model is not None, "cannot convert LSTMs without access to the full model"

        attrs = dict(n.attrs) # make a copy, which is safe to mutate
        hidden_size = attrs.pop('hidden_size')
        direction = force_unicode(attrs.pop('direction', 'forward'))
        assert not attrs, "unsupported LSTM attributes: " + str(attrs.keys())
        assert direction in ['forward', 'bidirectional'], "unsupported backwards LSTM"

        input_blob, W, R, B, sequence_lens, initial_h, initial_c = n.inputs

        if sequence_lens == "":
            sequence_lens = None

        input_size = cls._rnn_shape_inference(init_model, pred_model, n, input_blob, W)
        if input_size is None:
            raise RuntimeError("best-effort shape inference for LSTM input failed")

        init_net = core.Net("init-net")
        pred_mh = ModelHelper()

        def make_lstm(direction_offset):
            name = cls.dummy_name()

            # input and recurrence biases are squashed together in
            # onnx but not in caffe2

            bias_offset = 8 * direction_offset * hidden_size
            Bi = init_net.Slice(B, name + "_bias_i2h",
                                starts=[bias_offset + 0 * hidden_size],
                                ends =[bias_offset + 4 * hidden_size])
            Br = init_net.Slice(B, name + "_bias_gates",
                                starts=[bias_offset + 4 * hidden_size],
                                ends =[bias_offset + 8 * hidden_size])

            weight_offset = 4 * direction_offset * hidden_size
            W_ = init_net.Slice(W, name + '/i2h_w_pre',
                                starts=[weight_offset + 0 * hidden_size, 0],
                                ends =[weight_offset + 4 * hidden_size,-1])
            R_ = init_net.Slice(R, name + '/gates_t_w_pre',
                                starts=[weight_offset + 0 * hidden_size, 0],
                                ends =[weight_offset + 4 * hidden_size,-1])

            # caffe2 has a different order from onnx. We need to rearrange
            # i o f c -> i f o c
            reforms = ((W_, 'i2h_w', [(0, -1)]),
                       (R_, 'gates_t_w', [(0, -1)]),
                       (Bi, 'i2h_b' , []),
                       (Br, 'gates_t_b', []))
            for name_from, name_to, extra_dims in reforms:
                xi, xo, xf, xc = [name_from + suffix for suffix in ("_i", "_o", "_f", "_c")]
                for i, x in enumerate([xi, xo, xf, xc]):
                    dim0 = i * hidden_size, (i+1) * hidden_size
                    starts, ends = zip(dim0, *extra_dims)
                    init_net.Slice(name_from, x, starts=starts, ends=ends)
                init_net.Concat([xi, xf, xo, xc], ['%s/%s' % (name, name_to), cls.dummy_name()], axis=0)

            initial_h_sliced = name + '/initial_h'
            init_net.Slice(initial_h, initial_h_sliced,
                           starts=[direction_offset + 0, 0, 0],
                           ends =[direction_offset + 1,-1,-1])
            initial_c_sliced = name + '/initial_c'
            init_net.Slice(initial_c, initial_c_sliced,
                           starts=[direction_offset + 0, 0, 0],
                           ends =[direction_offset + 1,-1,-1])

            if direction_offset == 1:
                if sequence_lens is not None:
                    seq_lens_for_reverse = sequence_lens
                else:
                    input_shape = pred_mh.net.Shape(input_blob, name + '/input_shape')
                    batch_size = pred_mh.net.Slice(input_shape, name + '/batch_size_slice', starts=[1], ends=[2])
                    seq_len = pred_mh.net.Slice(input_shape, name + '/seq_len_slice', starts=[0], ends=[1])
                    dummy_sequence_lens = pred_mh.net.Tile([seq_len, batch_size], name + '/dummy_sequence_lens', axis=0)
                    pred_mh.net.Reshape(dummy_sequence_lens, [dummy_sequence_lens, cls.dummy_name()], shape=[-1])
                    seq_lens_for_reverse = dummy_sequence_lens

            if direction_offset == 1:
                input = pred_mh.net.ReversePackedSegs(
                    [input_blob, seq_lens_for_reverse], name + "/input-reversed")
            else:
                input = input_blob

            hidden_t_all, hidden_t_last, _, cell_last, params = rnn_cell.LSTM(
                pred_mh,
                input,
                sequence_lens,
                [initial_h_sliced, initial_c_sliced],
                input_size,
                hidden_size,
                name,
                drop_states=False,
                forward_only=True,
                return_params=True
            )

            if direction_offset == 1:
                hidden_t_all = pred_mh.net.ReversePackedSegs(
                    [hidden_t_all, seq_lens_for_reverse], name + "/output-reversed")

            return hidden_t_all, hidden_t_last, cell_last

        if direction == 'forward':
            hidden_t_all, hidden_t_last, cell_last = make_lstm(0)

            # in the forward case, storage is shared between the three
            # outputs. We need to decouple them so that the
            # VariableLengthSequencePadding only mutates n.outputs[0]
            pred_mh.net.Copy(hidden_t_last, n.outputs[1])
            pred_mh.net.Copy(cell_last, n.outputs[2])

            pred_mh.net = pred_mh.net.Clone(
                "dummy-clone-net",
                blob_remap={ hidden_t_all: n.outputs[0] }
            )
        elif direction == 'bidirectional':
            hidden_t_all_f, hidden_t_last_f, cell_last_f = make_lstm(0)
            hidden_t_all_b, hidden_t_last_b, cell_last_b = make_lstm(1)
            pred_mh.net.Concat([hidden_t_all_f, hidden_t_all_b],
                               [n.outputs[0], cls.dummy_name()], axis=2)
            pred_mh.net.Concat([hidden_t_last_f, hidden_t_last_b],
                               [n.outputs[1], cls.dummy_name()], axis=0)
            pred_mh.net.Concat([cell_last_f, cell_last_b],
                               [n.outputs[2], cls.dummy_name()], axis=0)

        if sequence_lens is not None:
            pred_mh.net.VariableLengthSequencePadding(
                [n.outputs[0], sequence_lens], [n.outputs[0]])

        return Caffe2Ops(list(pred_mh.Proto().op),
                         list(init_net.Proto().op),
                         list(pred_mh.Proto().external_input))

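The forward-direction branches above all use the same decoupling trick: Copy gives the secondary outputs their own storage, and Clone with blob_remap renames the primary output so that VariableLengthSequencePadding only touches n.outputs[0]. A minimal sketch of that Clone/blob_remap pattern with caffe2.python.core (blob names are illustrative, not from the diff):

from caffe2.python import core

net = core.Net('sketch')
y_all = net.ConstantFill([], 'y_all', shape=[1], value=0.0)
net.Copy(y_all, 'y_last')                     # give the secondary output its own blob
net = net.Clone('sketch-remapped',
                blob_remap={str(y_all): 'onnx_output_0'})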
    @classmethod
    def _create_gru(cls, init_model, pred_model, n, opset_version):
        assert init_model is not None, "cannot convert GRUs without access to the full model"
        assert pred_model is not None, "cannot convert GRUs without access to the full model"

        attrs = dict(n.attrs) # make a copy, which is safe to mutate
        hidden_size = attrs.pop('hidden_size')
        linear_before_reset = attrs.pop('linear_before_reset', 0)
        direction = force_unicode(attrs.pop('direction', 'forward'))
        assert not attrs, "unsupported GRU attributes: " + str(attrs.keys())
        assert direction in ['forward', 'bidirectional'], "unsupported backwards GRU"

        input_blob, W, R, B, sequence_lens, initial_h = n.inputs

        if sequence_lens == "":
            sequence_lens = None

        input_size = cls._rnn_shape_inference(init_model, pred_model, n, input_blob, W)
        if input_size is None:
            raise RuntimeError("best-effort shape inference for GRU input failed")

        init_net = core.Net("init-net")
        pred_mh = ModelHelper()

        def make_gru(direction_offset):
            name = cls.dummy_name()

            # input and recurrence biases are squashed together in
            # onnx but not in caffe2

            bias_offset = 6 * direction_offset * hidden_size
            Bi = init_net.Slice(B, name + "_bias_i2h",
                                starts=[bias_offset + 0 * hidden_size],
                                ends =[bias_offset + 3 * hidden_size])
            Br = init_net.Slice(B, name + "_bias_gates",
                                starts=[bias_offset + 3 * hidden_size],
                                ends =[bias_offset + 6 * hidden_size])

            weight_offset = 3 * direction_offset * hidden_size
            W_ = init_net.Slice(W, name + '/i2h_w_pre',
                                starts=[weight_offset + 0 * hidden_size, 0],
                                ends =[weight_offset + 3 * hidden_size,-1])
            R_ = init_net.Slice(R, name + '/gates_t_w_pre',
                                starts=[weight_offset + 0 * hidden_size, 0],
                                ends =[weight_offset + 3 * hidden_size,-1])

            # caffe2 has a different order from onnx. We need to rearrange
            # z r h -> r z h
            reforms = ((W_, 'i2h_w', True, [(0,-1)]),
                       (R_, 'gate_t_w', False, [(0,-1)]),
                       (Bi, 'i2h_b', True, []),
                       (Br, 'gate_t_b', False, []))
            for name_from, name_to, do_concat, extra_dims in reforms:
                xz, xr, xh = ['%s/%s_%s' % (name, prefix, name_to) for prefix in ('update', 'reset', 'output')]
                for i, x in enumerate([xz, xr, xh]):
                    dim0 = i * hidden_size, (i+1) * hidden_size
                    starts, ends = zip(dim0, *extra_dims)
                    init_net.Slice(name_from, x, starts=starts, ends=ends)
                if do_concat:
                    init_net.Concat([xr, xz, xh], ['%s/%s' % (name, name_to), cls.dummy_name()], axis=0)

            initial_h_sliced = name + '/initial_h'
            init_net.Slice(initial_h, initial_h_sliced,
                           starts=[direction_offset + 0, 0, 0],
                           ends =[direction_offset + 1,-1,-1])

            if direction_offset == 1:
                if sequence_lens is not None:
                    seq_lens_for_reverse = sequence_lens
                else:
                    input_shape = pred_mh.net.Shape(input_blob, name + '/input_shape')
                    batch_size = pred_mh.net.Slice(input_shape, name + '/batch_size_slice', starts=[1], ends=[2])
                    seq_len = pred_mh.net.Slice(input_shape, name + '/seq_len_slice', starts=[0], ends=[1])
                    dummy_sequence_lens = pred_mh.net.Tile([seq_len, batch_size], name + '/dummy_sequence_lens', axis=0)
                    pred_mh.net.Reshape(dummy_sequence_lens, [dummy_sequence_lens, cls.dummy_name()], shape=[-1])
                    seq_lens_for_reverse = dummy_sequence_lens

            if direction_offset == 1:
                input = pred_mh.net.ReversePackedSegs(
                    [input_blob, seq_lens_for_reverse], name + "/input-reversed")
            else:
                input = input_blob

            hidden_t_all, hidden_t_last = gru_cell.GRU(
                pred_mh,
                input,
                sequence_lens,
                [initial_h_sliced],
                input_size,
                hidden_size,
                name,
                drop_states=False,
                forward_only=True,
                linear_before_reset=linear_before_reset
            )

            if direction_offset == 1:
                hidden_t_all = pred_mh.net.ReversePackedSegs(
                    [hidden_t_all, seq_lens_for_reverse], name + "/output-reversed")

            return hidden_t_all, hidden_t_last

        if direction == 'forward':
            hidden_t_all, hidden_t_last = make_gru(0)

            # in the forward case, storage is shared between the two
            # outputs. We need to decouple them so that the
            # VariableLengthSequencePadding only mutates n.outputs[0]
            pred_mh.net.Copy(hidden_t_last, n.outputs[1])

            pred_mh.net = pred_mh.net.Clone(
                "dummy-clone-net",
                blob_remap={ hidden_t_all: n.outputs[0] }
            )
        elif direction == 'bidirectional':
            hidden_t_all_f, hidden_t_last_f = make_gru(0)
            hidden_t_all_b, hidden_t_last_b = make_gru(1)
            pred_mh.net.Concat([hidden_t_all_f, hidden_t_all_b],
                               [n.outputs[0], cls.dummy_name()], axis=2)
            pred_mh.net.Concat([hidden_t_last_f, hidden_t_last_b],
                               [n.outputs[1], cls.dummy_name()], axis=0)
            for i in range(1, len(n.outputs)):
                pred_mh.net.Concat([outputs_f[i], outputs_b[i]],
                                   [n.outputs[i], cls.dummy_name()], axis=0)

        if sequence_lens is not None:
            pred_mh.net.VariableLengthSequencePadding(