Refactor rnn export (#7263)

* rnn refactor: extract rnn weights and biases

* rnn refactor: make rnn with converted outputs

* rnn refactor: finish it off
anderspapitto
2018-05-04 14:00:09 -07:00
committed by GitHub
parent 55b8317f1d
commit 67a9948d87


@@ -188,9 +188,9 @@ class Caffe2Backend(Backend):
# the value is an attribute of this class that is a
# function from ToffeIR node_def to caffe2 op_def
_special_operators = {
'LSTM': '_create_lstm',
'GRU': '_create_gru',
'RNN': '_create_rnn',
'LSTM': '_create_rnn_variant',
'GRU': '_create_rnn_variant',
'RNN': '_create_rnn_variant',
'Loop': '_create_loop',
'If': '_create_if',
}
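
The table above routes each ONNX op type to the name of a converter method on the class; after this change all three recurrent ops share one handler. A minimal sketch of that string-based dispatch pattern, with a hypothetical Converter class standing in for Caffe2Backend:

# Illustrative only: Converter and its methods are hypothetical stand-ins.
class Converter(object):
    _special_operators = {
        'LSTM': '_create_rnn_variant',
        'GRU': '_create_rnn_variant',
        'RNN': '_create_rnn_variant',
    }

    @classmethod
    def convert(cls, op_type, node):
        # Look up the handler name, falling back to a generic translator.
        handler_name = cls._special_operators.get(op_type, '_generic_translator')
        return getattr(cls, handler_name)(node)

    @classmethod
    def _create_rnn_variant(cls, node):
        return 'rnn: %s' % node

    @classmethod
    def _generic_translator(cls, node):
        return 'generic: %s' % node
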
@@ -295,386 +295,204 @@ class Caffe2Backend(Backend):
return c2_op
@classmethod
def _rnn_shape_inference(cls, init_model, pred_model, n, input_blob, W):
for x in itertools.chain(init_model.graph.input,
init_model.graph.value_info,
pred_model.graph.input,
pred_model.graph.value_info):
if x.name == W:
return x.type.tensor_type.shape.dim[1].dim_value
def _rnn_reform_weights(cls, reforms, name, hidden_size, init_net, gates, reorder_indices):
for name_from, name_to, do_concat, extra_dims in reforms:
gate_blobs = ['%s/%s_%s' % (name, prefix, name_to) for prefix in gates]
for i, x in enumerate(gate_blobs):
dim0 = i * hidden_size, (i+1) * hidden_size
starts, ends = zip(dim0, *extra_dims)
init_net.Slice(name_from, x, starts=starts, ends=ends)
if do_concat:
reordered_gate_blobs = [gate_blobs[i] for i in reorder_indices]
init_net.Concat(reordered_gate_blobs, ['%s/%s' % (name, name_to), cls.dummy_name()], axis=0)
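
A minimal NumPy sketch of what _rnn_reform_weights computes for a GRU weight blob: each Slice pulls out one gate's rows from the ONNX-packed matrix, and the Concat re-stacks them in Caffe2's gate order (toy sizes and contents, illustration only):

import numpy as np

hidden_size = 2
# Hypothetical GRU input weights packed in ONNX gate order: z, r, h.
W = np.arange(3 * hidden_size * 4).reshape(3 * hidden_size, 4)

# One block of hidden_size rows per gate, as the Slice ops above produce.
gate_blobs = [W[i * hidden_size:(i + 1) * hidden_size] for i in range(3)]

# Reorder z, r, h -> r, z, h and re-stack, as the final Concat does.
reorder_indices = [1, 0, 2]
W_caffe2 = np.concatenate([gate_blobs[i] for i in reorder_indices], axis=0)
print(W_caffe2.shape)  # (6, 4)
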
@classmethod
def _create_rnn(cls, init_model, pred_model, n, opset_version):
def _make_rnn_direction(cls, input_blob, B, W, R, initial_states_and_names, sequence_lens,
pred_mh, init_net,
input_size, hidden_size, num_gates, direction_offset,
Bi, Br, W_, R_,
reform, make_cell, keep_outputs):
name = cls.dummy_name()
# input and recurrence biases are squashed together in onnx
# but not in caffe2
gates_hidden_size = num_gates * hidden_size
bias_offset = 2 * direction_offset * gates_hidden_size
weight_offset = direction_offset * gates_hidden_size
Bi = init_net.Slice(B, name + Bi,
starts=[bias_offset + 0 * gates_hidden_size],
ends =[bias_offset + 1 * gates_hidden_size])
Br = init_net.Slice(B, name + Br,
starts=[bias_offset + 1 * gates_hidden_size],
ends =[bias_offset + 2 * gates_hidden_size])
W_ = init_net.Slice(W, name + W_,
starts=[weight_offset + 0 * gates_hidden_size, 0],
ends =[weight_offset + 1 * gates_hidden_size,-1])
R_ = init_net.Slice(R, name + R_,
starts=[weight_offset + 0 * gates_hidden_size, 0],
ends =[weight_offset + 1 * gates_hidden_size,-1])
initial_states_sliced = []
for initial_state, name_suffix in initial_states_and_names:
initial_states_sliced.append(
init_net.Slice(initial_state, name + name_suffix,
starts=[direction_offset + 0, 0, 0],
ends =[direction_offset + 1,-1,-1]))
if direction_offset == 1:
if sequence_lens is not None:
seq_lens_for_reverse = sequence_lens
else:
input_shape = pred_mh.net.Shape(input_blob, name + '/input_shape')
batch_size = pred_mh.net.Slice(input_shape, name + '/batch_size_slice', starts=[1], ends=[2])
seq_len = pred_mh.net.Slice(input_shape, name + '/seq_len_slice', starts=[0], ends=[1])
dummy_sequence_lens = pred_mh.net.Tile([seq_len, batch_size], name + '/dummy_sequence_lens', axis=0)
pred_mh.net.Reshape(dummy_sequence_lens, [dummy_sequence_lens, cls.dummy_name()], shape=[-1])
seq_lens_for_reverse = dummy_sequence_lens
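
When the ONNX node carries no sequence_lens input, the block above synthesizes one so ReversePackedSegs can still run: the sequence length is tiled batch_size times and flattened. Roughly the NumPy equivalent, with toy numbers for illustration:

import numpy as np

input_shape = np.array([5, 3, 8])   # (seq_len, batch_size, input_size)
seq_len, batch_size = input_shape[0], input_shape[1]
dummy_sequence_lens = np.tile([seq_len], batch_size).reshape(-1)
print(dummy_sequence_lens)          # [5 5 5]
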
reform(Bi, Br, W_, R_, name, hidden_size, init_net)
if direction_offset == 1:
input = pred_mh.net.ReversePackedSegs(
[input_blob, seq_lens_for_reverse], name + "/input-reversed")
else:
input = input_blob
outputs = keep_outputs(list(make_cell(
pred_mh,
input,
sequence_lens,
initial_states_sliced,
input_size,
hidden_size,
name,
drop_states=False,
forward_only=True,
)))
if direction_offset == 1:
outputs[0] = pred_mh.net.ReversePackedSegs(
[outputs[0], seq_lens_for_reverse], name + "/output-reversed")
return outputs
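
The offset arithmetic near the top of _make_rnn_direction picks one direction's parameters out of the ONNX-packed tensors: B stacks the input and recurrence biases (2 * num_gates * hidden_size values) per direction, while W and R stack num_gates * hidden_size rows per direction. A sketch of the resulting slice ranges with toy sizes (illustration only):

hidden_size, num_gates = 4, 3    # e.g. a GRU
gates_hidden_size = num_gates * hidden_size

for direction_offset in (0, 1):  # forward, then backward
    bias_offset = 2 * direction_offset * gates_hidden_size
    weight_offset = direction_offset * gates_hidden_size
    print('Bi rows   ', (bias_offset, bias_offset + gates_hidden_size))
    print('Br rows   ', (bias_offset + gates_hidden_size, bias_offset + 2 * gates_hidden_size))
    print('W_/R_ rows', (weight_offset, weight_offset + gates_hidden_size))
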
@classmethod
def _create_rnn_variant(cls, init_model, pred_model, n, opset_version):
assert init_model is not None, "cannot convert RNNs without access to the full model"
assert pred_model is not None, "cannot convert RNNs without access to the full model"
attrs = dict(n.attrs) # make a copy, which is safe to mutate
hidden_size = attrs.pop('hidden_size')
activation = force_unicode(attrs.pop('activations', ('tanh',))[0])
direction = force_unicode(attrs.pop('direction', 'forward'))
if n.op_type == 'RNN':
activation = force_unicode(attrs.pop('activations', ('tanh',))[0])
elif n.op_type == 'GRU':
linear_before_reset = attrs.pop('linear_before_reset', 0)
assert not attrs, "unsupported RNN attributes: " + str(attrs.keys())
assert direction in ['forward', 'bidirectional'], "unsupported backwards RNN"
assert direction in ['forward', 'bidirectional'], "unsupported backwards RNN/GRU/LSTM"
input_blob, W, R, B, sequence_lens, initial_h = n.inputs
if n.op_type in ['RNN', 'GRU']:
input_blob, W, R, B, sequence_lens, initial_h = n.inputs
elif n.op_type == 'LSTM':
input_blob, W, R, B, sequence_lens, initial_h, initial_c = n.inputs
if sequence_lens == "":
sequence_lens = None
input_size = cls._rnn_shape_inference(init_model, pred_model, n, input_blob, W)
if input_size is None:
raise RuntimeError("best-effort shape inference for RNN input failed")
for x in itertools.chain(init_model.graph.input,
init_model.graph.value_info,
pred_model.graph.input,
pred_model.graph.value_info):
if x.name == W:
input_size = x.type.tensor_type.shape.dim[1].dim_value
break
else:
raise RuntimeError("best-effort shape inference for RNN/GRU/LSTM failed")
init_net = core.Net("init-net")
pred_mh = ModelHelper()
def make_rnn(direction_offset):
name = cls.dummy_name()
if n.op_type == 'RNN':
def reform(*args):
pass
# input and recurrence biases are squashed together in
# onnx but not in caffe2
def make_cell(*args, **kwargs):
return rnn_cell.BasicRNN(*args, activation=activation, **kwargs)
bias_offset = 2 * direction_offset * hidden_size
init_net.Slice(B, name + "/i2h_b",
starts=[bias_offset + 0 * hidden_size],
ends =[bias_offset + 1 * hidden_size])
init_net.Slice(B, name + "/gates_t_b",
starts=[bias_offset + 1 * hidden_size],
ends =[bias_offset + 2 * hidden_size])
def make_rnn(direction_offset):
return cls._make_rnn_direction(
input_blob, B, W, R, [(initial_h, '/initial_h')], sequence_lens,
pred_mh, init_net, input_size, hidden_size, 1, direction_offset,
"/i2h_b", "/gates_t_b", "/i2h_w", "/gates_t_w",
reform, make_cell, lambda x: x)
weight_offset = direction_offset * hidden_size
init_net.Slice(W, name + '/i2h_w',
starts=[weight_offset + 0 * hidden_size, 0],
ends =[weight_offset + 1 * hidden_size,-1])
init_net.Slice(R, name + '/gates_t_w',
starts=[weight_offset + 0 * hidden_size, 0],
ends =[weight_offset + 1 * hidden_size,-1])
elif n.op_type == 'GRU':
def reform(Bi, Br, W_, R_, name, hidden_size, init_net):
# caffe2 has a different order from onnx. We need to rearrange
# z r h -> r z h
reforms = ((W_, 'i2h_w', True, [(0,-1)]),
(R_, 'gate_t_w', False, [(0,-1)]),
(Bi, 'i2h_b', True, []),
(Br, 'gate_t_b', False, []))
cls._rnn_reform_weights(reforms, name, hidden_size, init_net,
['update', 'reset', 'output'], [1, 0, 2])
initial_h_sliced = name + '/initial_h'
init_net.Slice(initial_h, initial_h_sliced,
starts=[direction_offset + 0, 0, 0],
ends =[direction_offset + 1,-1,-1])
def make_cell(*args, **kwargs):
return gru_cell.GRU(*args, linear_before_reset=linear_before_reset, **kwargs)
def make_rnn(direction_offset):
return cls._make_rnn_direction(
input_blob, B, W, R, [(initial_h, '/initial_h')], sequence_lens,
pred_mh, init_net, input_size, hidden_size, 3, direction_offset,
"_bias_i2h", "_bias_gates", "/i2h_w_pre", "/gates_t_w_pre",
reform, make_cell, lambda x: x)
if direction_offset == 1:
if sequence_lens is not None:
seq_lens_for_reverse = sequence_lens
else:
input_shape = pred_mh.net.Shape(input_blob, name + '/input_shape')
batch_size = pred_mh.net.Slice(input_shape, name + '/batch_size_slice', starts=[1], ends=[2])
seq_len = pred_mh.net.Slice(input_shape, name + '/seq_len_slice', starts=[0], ends=[1])
dummy_sequence_lens = pred_mh.net.Tile([seq_len, batch_size], name + '/dummy_sequence_lens', axis=0)
pred_mh.net.Reshape(dummy_sequence_lens, [dummy_sequence_lens, cls.dummy_name()], shape=[-1])
seq_lens_for_reverse = dummy_sequence_lens
elif n.op_type == 'LSTM':
def reform(Bi, Br, W_, R_, name, hidden_size, init_net):
# caffe2 has a different order from onnx. We need to rearrange
# i o f c -> i f o c
reforms = ((W_, 'i2h_w', True, [(0, -1)]),
(R_, 'gates_t_w', True, [(0, -1)]),
(Bi, 'i2h_b' , True, []),
(Br, 'gates_t_b', True, []))
cls._rnn_reform_weights(reforms, name, hidden_size, init_net,
['input', 'output', 'forget', 'cell'], [0, 2, 1, 3])
if direction_offset == 1:
input = pred_mh.net.ReversePackedSegs(
[input_blob, seq_lens_for_reverse], name + "/input-reversed")
else:
input = input_blob
def make_cell(*args, **kwargs):
return rnn_cell.LSTM(*args, **kwargs)
hidden_t_all, hidden_t_last = rnn_cell.BasicRNN(
pred_mh,
input,
sequence_lens,
[initial_h_sliced],
input_size,
hidden_size,
name,
drop_states=False,
forward_only=True,
activation=activation
)
if direction_offset == 1:
hidden_t_all = pred_mh.net.ReversePackedSegs(
[hidden_t_all, seq_lens_for_reverse], name + "/output-reversed")
return hidden_t_all, hidden_t_last
def make_rnn(direction_offset):
return cls._make_rnn_direction(
input_blob, B, W, R, [(initial_h, '/initial_h'), (initial_c, '/initial_c')], sequence_lens,
pred_mh, init_net, input_size, hidden_size, 4, direction_offset,
"/i2h_b", "/gates_t_b", "/i2h_w", "/gates_t_w",
reform, make_cell, lambda x: [x[0], x[1], x[3]])
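
The keep_outputs callback here trims the cell trajectory: judging from the removed _create_lstm below, rnn_cell.LSTM returns (hidden_t_all, hidden_t_last, cell_t_all, cell_t_last), and indices 0, 1, 3 keep only the last cell state. Illustratively, with hypothetical blob names:

lstm_outputs = ['hidden_t_all', 'hidden_t_last', 'cell_t_all', 'cell_t_last']
keep_outputs = lambda x: [x[0], x[1], x[3]]
print(keep_outputs(lstm_outputs))  # ['hidden_t_all', 'hidden_t_last', 'cell_t_last']
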
if direction == 'forward':
hidden_t_all, hidden_t_last = make_rnn(0)
outputs = make_rnn(0)
# in the forward case, storage is shared between the two
# outputs. We need to decouple them so that the
# VariableLengthSequencePadding only mutates n.outputs[0]
pred_mh.net.Copy(hidden_t_last, n.outputs[1])
# in the forward case, storage is shared between the
# last outputs. We need to decouple them so that the
# VariableLengthSequencePadding only mutates
# n.outputs[0]
for i in range(1, len(outputs)):
pred_mh.net.Copy(outputs[i], n.outputs[i])
pred_mh.net = pred_mh.net.Clone(
"dummy-clone-net",
blob_remap={ hidden_t_all: n.outputs[0] }
"dummy-clone-net", blob_remap={ outputs[0]: n.outputs[0] }
)
elif direction == 'bidirectional':
hidden_t_all_f, hidden_t_last_f = make_rnn(0)
hidden_t_all_b, hidden_t_last_b = make_rnn(1)
pred_mh.net.Concat([hidden_t_all_f, hidden_t_all_b],
outputs_f = make_rnn(0)
outputs_b = make_rnn(1)
pred_mh.net.Concat([outputs_f[0], outputs_b[0]],
[n.outputs[0], cls.dummy_name()], axis=2)
pred_mh.net.Concat([hidden_t_last_f, hidden_t_last_b],
[n.outputs[1], cls.dummy_name()], axis=0)
if sequence_lens is not None:
pred_mh.net.VariableLengthSequencePadding(
[n.outputs[0], sequence_lens], [n.outputs[0]])
return Caffe2Ops(list(pred_mh.Proto().op),
list(init_net.Proto().op),
list(pred_mh.Proto().external_input))
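
In the bidirectional branch above, the per-step outputs of the two directions are concatenated along the hidden axis (axis=2) and the last-state outputs along the leading direction axis (axis=0). A NumPy sketch of the resulting shapes (toy sizes, illustration only):

import numpy as np

seq_len, batch, hidden = 5, 3, 4
h_all_f = np.zeros((seq_len, batch, hidden))
h_all_b = np.zeros((seq_len, batch, hidden))
h_last_f = np.zeros((1, batch, hidden))
h_last_b = np.zeros((1, batch, hidden))

y = np.concatenate([h_all_f, h_all_b], axis=2)      # (5, 3, 8)
y_h = np.concatenate([h_last_f, h_last_b], axis=0)  # (2, 3, 4)
print(y.shape, y_h.shape)
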
@classmethod
def _create_lstm(cls, init_model, pred_model, n, opset_version):
assert init_model is not None, "cannot convert LSTMs without access to the full model"
assert pred_model is not None, "cannot convert LSTMs without access to the full model"
attrs = dict(n.attrs) # make a copy, which is safe to mutate
hidden_size = attrs.pop('hidden_size')
direction = force_unicode(attrs.pop('direction', 'forward'))
assert not attrs, "unsupported LSTM attributes: " + str(attrs.keys())
assert direction in ['forward', 'bidirectional'], "unsupported backwards LSTM"
input_blob, W, R, B, sequence_lens, initial_h, initial_c = n.inputs
if sequence_lens == "":
sequence_lens = None
input_size = cls._rnn_shape_inference(init_model, pred_model, n, input_blob, W)
if input_size is None:
raise RuntimeError("best-effort shape inference for LSTM input failed")
init_net = core.Net("init-net")
pred_mh = ModelHelper()
def make_lstm(direction_offset):
name = cls.dummy_name()
# input and recurrence biases are squashed together in
# onnx but not in caffe2
bias_offset = 8 * direction_offset * hidden_size
Bi = init_net.Slice(B, name + "_bias_i2h",
starts=[bias_offset + 0 * hidden_size],
ends =[bias_offset + 4 * hidden_size])
Br = init_net.Slice(B, name + "_bias_gates",
starts=[bias_offset + 4 * hidden_size],
ends =[bias_offset + 8 * hidden_size])
weight_offset = 4 * direction_offset * hidden_size
W_ = init_net.Slice(W, name + '/i2h_w_pre',
starts=[weight_offset + 0 * hidden_size, 0],
ends =[weight_offset + 4 * hidden_size,-1])
R_ = init_net.Slice(R, name + '/gates_t_w_pre',
starts=[weight_offset + 0 * hidden_size, 0],
ends =[weight_offset + 4 * hidden_size,-1])
# caffe2 has a different order from onnx. We need to rearrange
# i o f c -> i f o c
reforms = ((W_, 'i2h_w', [(0, -1)]),
(R_, 'gates_t_w', [(0, -1)]),
(Bi, 'i2h_b' , []),
(Br, 'gates_t_b', []))
for name_from, name_to, extra_dims in reforms:
xi, xo, xf, xc = [name_from + suffix for suffix in ("_i", "_o", "_f", "_c")]
for i, x in enumerate([xi, xo, xf, xc]):
dim0 = i * hidden_size, (i+1) * hidden_size
starts, ends = zip(dim0, *extra_dims)
init_net.Slice(name_from, x, starts=starts, ends=ends)
init_net.Concat([xi, xf, xo, xc], ['%s/%s' % (name, name_to), cls.dummy_name()], axis=0)
initial_h_sliced = name + '/initial_h'
init_net.Slice(initial_h, initial_h_sliced,
starts=[direction_offset + 0, 0, 0],
ends =[direction_offset + 1,-1,-1])
initial_c_sliced = name + '/initial_c'
init_net.Slice(initial_c, initial_c_sliced,
starts=[direction_offset + 0, 0, 0],
ends =[direction_offset + 1,-1,-1])
if direction_offset == 1:
if sequence_lens is not None:
seq_lens_for_reverse = sequence_lens
else:
input_shape = pred_mh.net.Shape(input_blob, name + '/input_shape')
batch_size = pred_mh.net.Slice(input_shape, name + '/batch_size_slice', starts=[1], ends=[2])
seq_len = pred_mh.net.Slice(input_shape, name + '/seq_len_slice', starts=[0], ends=[1])
dummy_sequence_lens = pred_mh.net.Tile([seq_len, batch_size], name + '/dummy_sequence_lens', axis=0)
pred_mh.net.Reshape(dummy_sequence_lens, [dummy_sequence_lens, cls.dummy_name()], shape=[-1])
seq_lens_for_reverse = dummy_sequence_lens
if direction_offset == 1:
input = pred_mh.net.ReversePackedSegs(
[input_blob, seq_lens_for_reverse], name + "/input-reversed")
else:
input = input_blob
hidden_t_all, hidden_t_last, _, cell_last, params = rnn_cell.LSTM(
pred_mh,
input,
sequence_lens,
[initial_h_sliced, initial_c_sliced],
input_size,
hidden_size,
name,
drop_states=False,
forward_only=True,
return_params=True
)
if direction_offset == 1:
hidden_t_all = pred_mh.net.ReversePackedSegs(
[hidden_t_all, seq_lens_for_reverse], name + "/output-reversed")
return hidden_t_all, hidden_t_last, cell_last
if direction == 'forward':
hidden_t_all, hidden_t_last, cell_last = make_lstm(0)
# in the forward case, storage is shared between the three
# outputs. We need to decouple them so that the
# VariableLengthSequencePadding only mutates n.outputs[0]
pred_mh.net.Copy(hidden_t_last, n.outputs[1])
pred_mh.net.Copy(cell_last, n.outputs[2])
pred_mh.net = pred_mh.net.Clone(
"dummy-clone-net",
blob_remap={ hidden_t_all: n.outputs[0] }
)
elif direction == 'bidirectional':
hidden_t_all_f, hidden_t_last_f, cell_last_f = make_lstm(0)
hidden_t_all_b, hidden_t_last_b, cell_last_b = make_lstm(1)
pred_mh.net.Concat([hidden_t_all_f, hidden_t_all_b],
[n.outputs[0], cls.dummy_name()], axis=2)
pred_mh.net.Concat([hidden_t_last_f, hidden_t_last_b],
[n.outputs[1], cls.dummy_name()], axis=0)
pred_mh.net.Concat([cell_last_f, cell_last_b],
[n.outputs[2], cls.dummy_name()], axis=0)
if sequence_lens is not None:
pred_mh.net.VariableLengthSequencePadding(
[n.outputs[0], sequence_lens], [n.outputs[0]])
return Caffe2Ops(list(pred_mh.Proto().op),
list(init_net.Proto().op),
list(pred_mh.Proto().external_input))
@classmethod
def _create_gru(cls, init_model, pred_model, n, opset_version):
assert init_model is not None, "cannot convert GRUs without access to the full model"
assert pred_model is not None, "cannot convert GRUs without access to the full model"
attrs = dict(n.attrs) # make a copy, which is safe to mutate
hidden_size = attrs.pop('hidden_size')
linear_before_reset = attrs.pop('linear_before_reset', 0)
direction = force_unicode(attrs.pop('direction', 'forward'))
assert not attrs, "unsupported GRU attributes: " + str(attrs.keys())
assert direction in ['forward', 'bidirectional'], "unsupported backwards GRU"
input_blob, W, R, B, sequence_lens, initial_h = n.inputs
if sequence_lens == "":
sequence_lens = None
input_size = cls._rnn_shape_inference(init_model, pred_model, n, input_blob, W)
if input_size is None:
raise RuntimeError("best-effort shape inference for GRU input failed")
init_net = core.Net("init-net")
pred_mh = ModelHelper()
def make_gru(direction_offset):
name = cls.dummy_name()
# input and recurrence biases are squashed together in
# onnx but not in caffe2
bias_offset = 6 * direction_offset * hidden_size
Bi = init_net.Slice(B, name + "_bias_i2h",
starts=[bias_offset + 0 * hidden_size],
ends =[bias_offset + 3 * hidden_size])
Br = init_net.Slice(B, name + "_bias_gates",
starts=[bias_offset + 3 * hidden_size],
ends =[bias_offset + 6 * hidden_size])
weight_offset = 3 * direction_offset * hidden_size
W_ = init_net.Slice(W, name + '/i2h_w_pre',
starts=[weight_offset + 0 * hidden_size, 0],
ends =[weight_offset + 3 * hidden_size,-1])
R_ = init_net.Slice(R, name + '/gates_t_w_pre',
starts=[weight_offset + 0 * hidden_size, 0],
ends =[weight_offset + 3 * hidden_size,-1])
# caffe2 has a different order from onnx. We need to rearrange
# z r h -> r z h
reforms = ((W_, 'i2h_w', True, [(0,-1)]),
(R_, 'gate_t_w', False, [(0,-1)]),
(Bi, 'i2h_b', True, []),
(Br, 'gate_t_b', False, []))
for name_from, name_to, do_concat, extra_dims in reforms:
xz, xr, xh = ['%s/%s_%s' % (name, prefix, name_to) for prefix in ('update', 'reset', 'output')]
for i, x in enumerate([xz, xr, xh]):
dim0 = i * hidden_size, (i+1) * hidden_size
starts, ends = zip(dim0, *extra_dims)
init_net.Slice(name_from, x, starts=starts, ends=ends)
if do_concat:
init_net.Concat([xr, xz, xh], ['%s/%s' % (name, name_to), cls.dummy_name()], axis=0)
initial_h_sliced = name + '/initial_h'
init_net.Slice(initial_h, initial_h_sliced,
starts=[direction_offset + 0, 0, 0],
ends =[direction_offset + 1,-1,-1])
if direction_offset == 1:
if sequence_lens is not None:
seq_lens_for_reverse = sequence_lens
else:
input_shape = pred_mh.net.Shape(input_blob, name + '/input_shape')
batch_size = pred_mh.net.Slice(input_shape, name + '/batch_size_slice', starts=[1], ends=[2])
seq_len = pred_mh.net.Slice(input_shape, name + '/seq_len_slice', starts=[0], ends=[1])
dummy_sequence_lens = pred_mh.net.Tile([seq_len, batch_size], name + '/dummy_sequence_lens', axis=0)
pred_mh.net.Reshape(dummy_sequence_lens, [dummy_sequence_lens, cls.dummy_name()], shape=[-1])
seq_lens_for_reverse = dummy_sequence_lens
if direction_offset == 1:
input = pred_mh.net.ReversePackedSegs(
[input_blob, seq_lens_for_reverse], name + "/input-reversed")
else:
input = input_blob
hidden_t_all, hidden_t_last = gru_cell.GRU(
pred_mh,
input,
sequence_lens,
[initial_h_sliced],
input_size,
hidden_size,
name,
drop_states=False,
forward_only=True,
linear_before_reset=linear_before_reset
)
if direction_offset == 1:
hidden_t_all = pred_mh.net.ReversePackedSegs(
[hidden_t_all, seq_lens_for_reverse], name + "/output-reversed")
return hidden_t_all, hidden_t_last
if direction == 'forward':
hidden_t_all, hidden_t_last = make_gru(0)
# in the forward case, storage is shared between the two
# outputs. We need to decouple them so that the
# VariableLengthSequencePadding only mutates n.outputs[0]
pred_mh.net.Copy(hidden_t_last, n.outputs[1])
pred_mh.net = pred_mh.net.Clone(
"dummy-clone-net",
blob_remap={ hidden_t_all: n.outputs[0] }
)
elif direction == 'bidirectional':
hidden_t_all_f, hidden_t_last_f = make_gru(0)
hidden_t_all_b, hidden_t_last_b = make_gru(1)
pred_mh.net.Concat([hidden_t_all_f, hidden_t_all_b],
[n.outputs[0], cls.dummy_name()], axis=2)
pred_mh.net.Concat([hidden_t_last_f, hidden_t_last_b],
[n.outputs[1], cls.dummy_name()], axis=0)
for i in range(1, len(n.outputs)):
pred_mh.net.Concat([outputs_f[i], outputs_b[i]],
[n.outputs[i], cls.dummy_name()], axis=0)
if sequence_lens is not None:
pred_mh.net.VariableLengthSequencePadding(