Implement a Caffe2 standalone LSTM operator (#17726)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/17726
Pull Request resolved: https://github.com/pytorch/pytorch/pull/17725
Pull Request resolved: https://github.com/pytorch/pytorch/pull/17461

Implements a standalone LSTM operator in Caffe2, adopted from the ATen implementation in diffusion/FBS/browse/master/fbcode/caffe2/aten/src/ATen/native/RNN.cpp. The trickiest part was that caffe2::Tensor has no copy constructor, which made it necessary to implement a custom templated copy helper (copy_ctor) for the different Tensor containers used in the code. There was also no easy way to reuse off-the-shelf C2 operators here, so basic matmul, cat, split, transpose, and linear routines were copied in as utility functions.

Two things are still missing:
- Profiling this implementation against the current ONNXified LSTM op
- Making this operator available for use from PyTorch

Reviewed By: dzhulgakov

Differential Revision: D14351575

fbshipit-source-id: 3b99b53212cf593c7a49e45580b5a07b90809e64
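For context, the new operator is exposed to PyTorch through the custom-op registry as torch.ops._caffe2.InferenceLSTM; the unit test at the bottom of this diff drives it that way. Below is a minimal sketch of that call pattern, with illustrative shapes, assuming a build in which the Caffe2 operators are registered:

# Sketch only: mirrors the call pattern in the unit test below; shapes are
# illustrative, and torch.ops._caffe2 assumes the Caffe2 ops are registered.
import torch
from torch import nn

seq_len, bsz, emb_dim, hidden_size = 4, 2, 8, 16
ref_lstm = nn.LSTM(emb_dim, hidden_size, num_layers=1, bias=True)  # reference weights

inputs = torch.randn(seq_len, bsz, emb_dim)   # batch_first=False layout
h0 = torch.zeros(1, bsz, hidden_size)         # (num_layers * num_directions, bsz, hidden)
c0 = torch.zeros(1, bsz, hidden_size)

# input_list packs [input, h0, c0] followed by the flat LSTM weights
# (w_ih, w_hh, b_ih, b_hh per layer/direction), which gather_params regroups.
lstm_in = [inputs, h0, c0] + [p.detach() for p in ref_lstm._flat_weights]

output, hidden, cell = torch.ops._caffe2.InferenceLSTM(
    lstm_in,
    1,      # num_layers
    True,   # has_biases
    False,  # batch_first
    False,  # bidirectional
)

The weight ordering matters: gather_params consumes the flat parameter list in groups of four when has_biases is true (or two otherwise), matching the layout of nn.LSTM's flat weights.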
Committed by: Facebook Github Bot
Parent: 7d02a1fbc7
Commit: f8778aef78
caffe2/operators/inference_lstm_op.cc (new file, 71 lines)
@@ -0,0 +1,71 @@
#include "caffe2/operators/inference_lstm_op.h"

namespace caffe2 {
namespace {

bool InferenceLSTMOp::RunOnDevice() {
  auto& _input = Input(0);
  auto& hidden_0 = Input(1);
  auto& hidden_1 = Input(2);
  std::vector<Tensor> params;
  for (int i = 3; i < InputSize(); i++) {
    params.push_back(Input(i).UnsafeSharedInstance());
  }
  auto input = batch_first_ ? transpose(_input, 0, 1, &context_)
                            : _input.UnsafeSharedInstance();

  auto cell_params = gather_params(params, has_biases_, &context_);
  auto results = _lstm_impl(
      input,
      cell_params,
      hidden_0,
      hidden_1,
      num_layers_,
      bidirectional_,
      &context_);

  std::vector<Tensor> allOutputs(OutputSize());
  allOutputs.at(0) = copy_ctor(std::get<0>(results));
  if (batch_first_) {
    allOutputs.at(0) = transpose(allOutputs.at(0), 0, 1, &context_);
  }
  allOutputs.at(1) = copy_ctor(std::get<1>(results));
  allOutputs.at(2) = copy_ctor(std::get<2>(results));
  for (int i = 0; i < OutputSize(); i++) {
    auto output = XOutput(i, allOutputs.at(i).sizes(), dtype<float>());
    context_.CopyItemsSameDevice(
        allOutputs.at(i).dtype(),
        allOutputs.at(i).numel(),
        allOutputs.at(i).template data<float>(),
        output.template mutable_data<float>());
  }
  return true;
}

REGISTER_CPU_OPERATOR(InferenceLSTM, InferenceLSTMOp);
OPERATOR_SCHEMA(InferenceLSTM)
    .NumInputs(1, INT_MAX)
    .NumOutputs(3)
    .Output(0, "output", "the output of the last layer of lstm")
    .Output(1, "hidden", "hidden state at t = seq_len")
    .Output(2, "cell", "cell state at t = seq_len")
    .Arg("num_layers", "(*long*): number of layers in the lstm stack")
    .Arg("has_biases", "(*bool*): whether the cells have biases or not")
    .Arg("batch_first", "(*bool*): whether the batch is at dim 0")
    .Arg("bidirectional", "(*bool*): if bidirectional");
NO_GRADIENT(InferenceLSTM);
} // namespace
} // namespace caffe2

C10_REGISTER_CAFFE2_OPERATOR_CPU(
    InferenceLSTM,
    (std::vector<c10::Argument>{
        c10::Argument("input_list", ListType::ofTensors()),
        c10::Argument("num_layers", IntType::get()),
        c10::Argument("has_biases", BoolType::get()),
        c10::Argument("batch_first", BoolType::get()),
        c10::Argument("bidirectional", BoolType::get())}),
    (std::vector<c10::Argument>{c10::Argument("output"),
                                c10::Argument("hidden"),
                                c10::Argument("cell")}),
    caffe2::InferenceLSTMOp);
caffe2/operators/inference_lstm_op.h (new file, 310 lines)
@@ -0,0 +1,310 @@
#ifndef LSTM_OP_H_
#define LSTM_OP_H_

#include <c10/core/Tensor.h>
#include <algorithm>
#include <sstream>
#include <unordered_map>
#include <vector>
#include "caffe2/core/blob_serialization.h"
#include "caffe2/core/operator.h"
#include "caffe2/core/tensor.h"
#include "caffe2/utils/eigen_utils.h"
#include "caffe2/utils/math.h"
#include "lstm_utils.h"

C10_DECLARE_CAFFE2_OPERATOR(LSTMOp);

namespace caffe2 {
namespace {

using t_tuple = std::tuple<Tensor, Tensor>;

struct CellParams {
  CellParams(
      const Tensor& _w_ih,
      const Tensor& _w_hh,
      const Tensor& _b_ih,
      const Tensor& _b_hh,
      CPUContext* _context) {
    initParams(_w_ih, _w_hh, _b_ih, _b_hh, _context);
  }

  CellParams(const CellParams& rhs) {
    initParams(rhs.w_ih, rhs.w_hh, rhs.b_ih, rhs.b_hh, rhs.context);
  }

  CellParams& operator=(const CellParams& rhs) {
    initParams(rhs.w_ih, rhs.w_hh, rhs.b_ih, rhs.b_hh, rhs.context);
    return *this;
  }

  void initParams(
      const Tensor& _w_ih,
      const Tensor& _w_hh,
      const Tensor& _b_ih,
      const Tensor& _b_hh,
      CPUContext* _context) {
    w_ih = copy_ctor(_w_ih);
    w_hh = copy_ctor(_w_hh);
    b_ih = copy_ctor(_b_ih);
    b_hh = copy_ctor(_b_hh);
    context = _context;
  }

  Tensor w_ih;
  Tensor w_hh;
  Tensor b_ih; /* optional */
  Tensor b_hh; /* optional */
  CPUContext* context;

  Tensor linear_ih(const Tensor& input) const {
    return linear(input, w_ih, b_ih, context);
  }
  Tensor linear_hh(const Tensor& h) const {
    return linear(h, w_hh, b_hh, context);
  }
};

struct LSTMCell {
  explicit LSTMCell(CPUContext* context) : context_(context) {}
  t_tuple operator()(
      const Tensor& input,
      const t_tuple& hidden,
      const CellParams& params) const {
    const auto& hx = std::get<0>(hidden);
    const auto& cx = std::get<1>(hidden);
    auto linear_ih = params.linear_ih(input);
    auto linear_hh = params.linear_hh(hx);
    auto gates = add(linear_ih, linear_hh, context_);
    auto chunked_gates = chunk(gates, 4, 1, context_);
    auto ingate = sigmoid(chunked_gates[0]);
    auto forgetgate = sigmoid(chunked_gates[1]);
    auto cellgate = tanh(chunked_gates[2], context_);
    auto outgate = sigmoid(chunked_gates[3]);

    auto cy =
        add(mul(forgetgate, cx, context_),
            mul(ingate, cellgate, context_),
            context_);
    auto hy = mul(outgate, tanh(cy, context_), context_);
    return std::make_tuple(std::move(hy), std::move(cy));
  }
  CPUContext* context_;
};

template <typename output_type, typename hidden_type>
struct LayerOutput {
  output_type outputs;
  hidden_type final_hidden;

  LayerOutput(const output_type& _outputs, const hidden_type& _hidden) {
    outputs = copy_ctor(_outputs);
    final_hidden = copy_ctor(_hidden);
  }
};

template <typename hidden_type, typename param_type>
struct Layer {
  using output_type = LayerOutput<Tensor, hidden_type>;
  virtual ~Layer() {}
  virtual output_type operator()(
      const Tensor& input,
      const hidden_type& input_hidden,
      const param_type& params) const = 0;
};

struct FullLSTMLayer : Layer<t_tuple, CellParams> {
  FullLSTMLayer(LSTMCell& cell, CPUContext* context)
      : cell_(cell), context_(context) {}

  LayerOutput<std::vector<Tensor>, t_tuple> operator()(
      const std::vector<Tensor>& step_inputs,
      const std::tuple<Tensor, Tensor>& input_hidden,
      const CellParams& params) const {
    std::vector<Tensor> step_outputs;
    auto hidden = copy_ctor(input_hidden);

    for (size_t i = 0; i < step_inputs.size(); i++) {
      hidden = cell_(step_inputs[i], hidden, params);
      step_outputs.push_back(copy_ctor(std::get<0>(hidden)));
    }

    return {step_outputs, hidden};
  }

  LayerOutput<Tensor, t_tuple> operator()(
      const Tensor& inputs,
      const std::tuple<Tensor, Tensor>& input_hidden,
      const CellParams& params) const override {
    auto unstacked_output =
        (*this)(unbind(inputs, 0, context_), input_hidden, params);
    return {stack(unstacked_output.outputs, 0, context_),
            unstacked_output.final_hidden};
  }
  LSTMCell cell_;
  CPUContext* context_;
};

struct FullBidirectionalLSTMLayer
    : Layer<std::pair<t_tuple, t_tuple>, std::pair<CellParams, CellParams>> {
  using bidir_hidden_type = std::pair<t_tuple, t_tuple>;
  using param_type = std::pair<CellParams, CellParams>;
  using output_type = LayerOutput<Tensor, bidir_hidden_type>;

  FullBidirectionalLSTMLayer(LSTMCell& cell, CPUContext* context)
      : layer_(cell, context), context_(context) {}

  output_type operator()(
      const Tensor& input,
      const bidir_hidden_type& input_hidden,
      const param_type& params) const override {
    std::vector<Tensor> outputs;
    auto step_inputs = unbind(input, 0, context_);
    auto fw_result = layer_(step_inputs, input_hidden.first, params.first);
    auto fw_output = stack(fw_result.outputs, 0, context_);
    outputs.push_back(copy_ctor(fw_output));
    auto rev_step_inputs = reverse(std::move(step_inputs));
    auto rev_result =
        layer_(rev_step_inputs, input_hidden.second, params.second);
    std::reverse(rev_result.outputs.begin(), rev_result.outputs.end());
    auto rev_output = stack(rev_result.outputs, 0, context_);
    outputs.push_back(copy_ctor(rev_output));
    return {cat(outputs, fw_output.dim() - 1, context_),
            std::make_pair(
                std::move(fw_result.final_hidden),
                std::move(rev_result.final_hidden))};
  }

  inline std::vector<Tensor> reverse(std::vector<Tensor>&& x) const {
    std::reverse(x.begin(), x.end());
    return std::move(x);
  }

 private:
  FullLSTMLayer layer_;
  CPUContext* context_;
};

template <typename hidden_type, typename weight_type>
LayerOutput<Tensor, std::vector<hidden_type>> apply_layer_stack(
    const Layer<hidden_type, weight_type>& layer,
    const Tensor& input,
    const std::vector<hidden_type>& hiddens,
    const std::vector<weight_type>& weights,
    int64_t num_layers) {
  CAFFE_ENFORCE(
      num_layers == hiddens.size(),
      "Expected more hidden states in stacked_rnn");
  CAFFE_ENFORCE(
      num_layers == weights.size(), "Expected more weights in stacked_rnn");

  auto layer_input = input.UnsafeSharedInstance();
  auto hidden_it = hiddens.begin();
  auto weight_it = weights.begin();
  std::vector<hidden_type> final_hiddens(num_layers);
  for (int64_t l = 0; l < num_layers; ++l) {
    auto layer_output = layer(layer_input, *(hidden_it++), *(weight_it++));
    final_hiddens.at(l) = std::move(layer_output.final_hidden);
    layer_input = std::move(layer_output.outputs);
  }
  return {layer_input, final_hiddens};
}

std::tuple<Tensor, Tensor, Tensor> _lstm_impl(
    const Tensor& input,
    const std::vector<CellParams>& params,
    const Tensor& hx,
    const Tensor& cx,
    int64_t num_layers,
    bool bidirectional,
    CPUContext* context) {
  using stack_output = LayerOutput<Tensor, std::vector<t_tuple>>;
  auto layer_hx = unbind(hx, 0, context);
  auto layer_cx = unbind(cx, 0, context);
  int64_t total_layers = layer_hx.size();
  std::vector<std::tuple<Tensor, Tensor>> hiddens;
  hiddens.reserve(total_layers);
  for (int64_t i = 0; i < total_layers; ++i) {
    hiddens.emplace_back(std::move(layer_hx[i]), std::move(layer_cx[i]));
  }
  LSTMCell cell(context);
  std::shared_ptr<stack_output> stack_output_ptr;
  if (bidirectional) {
    auto bidir_result = apply_layer_stack(
        FullBidirectionalLSTMLayer{cell, context},
        input,
        pair_vec(hiddens),
        pair_vec(params),
        num_layers);
    stack_output_ptr.reset(new stack_output(
        bidir_result.outputs,
        unpair_vec(std::move(bidir_result.final_hidden))));
  } else {
    auto result = apply_layer_stack(
        FullLSTMLayer{cell, context}, input, hiddens, params, num_layers);
    stack_output_ptr = std::make_shared<stack_output>(std::move(result));
  }

  std::vector<Tensor> hy, cy;
  hy.reserve(total_layers);
  cy.reserve(total_layers);
  for (auto& hidden : stack_output_ptr->final_hidden) {
    hy.push_back(std::move(std::get<0>(hidden)));
    cy.push_back(std::move(std::get<1>(hidden)));
  }
  return std::make_tuple(
      std::move(stack_output_ptr->outputs),
      stack(hy, 0, context),
      stack(cy, 0, context));
}

// Parses a flat list of parameter tensors into a list of CellParams
std::vector<CellParams> gather_params(
    const std::vector<Tensor>& params,
    bool has_biases,
    CPUContext* context) {
  Tensor undefined;
  std::vector<CellParams> result;
  if (has_biases) {
    CAFFE_ENFORCE_EQ(
        params.size() % 4, 0, "got an incorrect number of LSTM parameters");
    for (size_t i = 0; i < params.size(); i += 4) {
      result.emplace_back(
          params[i], params[i + 1], params[i + 2], params[i + 3], context);
    }
  } else {
    CAFFE_ENFORCE_EQ(
        params.size() % 2, 0, "got an incorrect number of LSTM parameters");
    for (size_t i = 0; i < params.size(); i += 2) {
      result.emplace_back(
          params[i], params[i + 1], undefined, undefined, context);
    }
  }
  return result;
}

class InferenceLSTMOp : public Operator<CPUContext> {
 public:
  template <class... Args>
  explicit InferenceLSTMOp(Args&&... args)
      : Operator(std::forward<Args>(args)...),
        num_layers_(this->template GetSingleArgument<int64_t>("num_layers", 1)),
        bidirectional_(
            this->template GetSingleArgument<bool>("bidirectional", false)),
        has_biases_(this->template GetSingleArgument<bool>("has_biases", true)),
        batch_first_(
            this->template GetSingleArgument<bool>("batch_first", false)) {}

  bool RunOnDevice() override;

 protected:
  int64_t num_layers_;
  bool bidirectional_;
  bool has_biases_;
  bool batch_first_;
};

} // namespace
} // namespace caffe2
#endif // LSTM_OP_H_
caffe2/operators/lstm_utils.h (new file, 318 lines)
@@ -0,0 +1,318 @@
#include <algorithm>
#include <vector>
#include "caffe2/core/tensor.h"
#include "caffe2/utils/eigen_utils.h"
#include "caffe2/utils/math.h"

namespace caffe2 {
namespace {

using t_tuple = std::tuple<Tensor, Tensor>;

template <typename T>
T copy_ctor(const T& x) {
  return x;
}

template <>
Tensor copy_ctor(const Tensor& X) {
  return X.UnsafeSharedInstance();
}

template <>
t_tuple copy_ctor(const t_tuple& X) {
  return std::make_tuple(copy_ctor(std::get<0>(X)), copy_ctor(std::get<1>(X)));
}

template <>
std::pair<t_tuple, t_tuple> copy_ctor(const std::pair<t_tuple, t_tuple>& X) {
  return std::make_pair(copy_ctor(X.first), copy_ctor(X.second));
}

template <>
std::vector<Tensor> copy_ctor(const std::vector<Tensor>& X) {
  std::vector<Tensor> Y(X.size());
  std::transform(X.begin(), X.end(), Y.begin(), [](const Tensor& x) {
    return copy_ctor(x);
  });
  return Y;
}

template <>
std::vector<t_tuple> copy_ctor(const std::vector<t_tuple>& X) {
  std::vector<t_tuple> Y(X.size());
  std::transform(X.begin(), X.end(), Y.begin(), [](const t_tuple& x) {
    return copy_ctor(x);
  });
  return Y;
}

template <>
std::vector<std::pair<t_tuple, t_tuple>> copy_ctor(
    const std::vector<std::pair<t_tuple, t_tuple>>& X) {
  std::vector<std::pair<t_tuple, t_tuple>> Y(X.size());
  std::transform(
      X.begin(), X.end(), Y.begin(), [](const std::pair<t_tuple, t_tuple>& x) {
        return copy_ctor(x);
      });
  return Y;
}

// Gathers every two elements of a vector in a vector of pairs
template <typename T>
static std::vector<std::pair<T, T>> pair_vec(const std::vector<T>& vals) {
  CAFFE_ENFORCE_EQ(
      vals.size() % 2,
      0,
      "Odd number of params or hiddens given to a bidirectional RNN");
  std::vector<std::pair<T, T>> result;
  result.reserve(vals.size() / 2);
  for (int64_t i = 0; i < vals.size(); i += 2) {
    result.emplace_back(copy_ctor(vals[i]), copy_ctor(vals[i + 1]));
  }
  return result;
}

// Flattens a vector of pairs
template <typename T>
static std::vector<T> unpair_vec(std::vector<std::pair<T, T>>&& vals) {
  std::vector<T> result;
  result.reserve(vals.size() * 2);
  for (int64_t i = 0; i < vals.size(); i++) {
    result.push_back(std::move(vals[i].first));
    result.push_back(std::move(vals[i].second));
  }
  return result;
}

Tensor matmul(const Tensor& X, const Tensor& W, CPUContext* context) {
  const auto canonical_axis = X.canonical_axis_index(1);
  const auto M = X.size_to_dim(canonical_axis);
  const auto K = X.size_from_dim(canonical_axis);
  const auto canonical_axis_w = W.canonical_axis_index(1);
  const int N = W.size_to_dim(canonical_axis_w);
  auto output_size = X.sizes().vec();
  output_size.resize(canonical_axis + 1);
  output_size[canonical_axis] = N;
  Tensor C(output_size, CPU);
  math::Gemm<float, CPUContext>(
      CblasNoTrans,
      CblasTrans,
      M,
      N,
      K,
      1,
      X.template data<float>(),
      W.template data<float>(),
      0,
      C.template mutable_data<float>(),
      context);
  return C;
}

Tensor
linear(const Tensor& X, const Tensor& W, const Tensor& B, CPUContext* context) {
  auto output = matmul(X, W, context);
  if (B) {
    const auto canonical_axis = X.canonical_axis_index(1);
    const auto M = X.size_to_dim(canonical_axis);
    const auto canonical_axis_w = W.canonical_axis_index(1);
    const int N = W.size_to_dim(canonical_axis_w);
    auto bias_multiplier_ = caffe2::empty({M}, CPU);
    math::Set<float, CPUContext>(
        M, 1, bias_multiplier_.template mutable_data<float>(), context);
    math::Gemm<float, CPUContext>(
        CblasNoTrans,
        CblasNoTrans,
        M,
        N,
        1,
        1,
        bias_multiplier_.template data<float>(),
        B.template data<float>(),
        1,
        output.template mutable_data<float>(),
        context);
  }
  return output;
}

std::vector<Tensor>
chunk(const Tensor& input, int chunks, int axis, CPUContext* context) {
  int canonical_axis = input.canonical_axis_index(axis);
  CAFFE_ENFORCE_LT(
      canonical_axis, input.dim(), "Axis not in input ndim range.");
  const int input_channels = input.dim32(canonical_axis);
  CAFFE_ENFORCE_EQ(
      input_channels % chunks,
      0,
      "input channels should be divisible by the number of chunks.");
  auto split_size = input_channels / chunks;
  vector<int64_t> output_dims(input.sizes().vec());
  int before = 1, after = 1;
  for (int i = 0; i < canonical_axis; ++i) {
    before *= input.dim32(i);
  }
  for (int i = canonical_axis + 1; i < input.dim(); ++i) {
    after *= input.dim32(i);
  }
  size_t input_offset = 0;
  std::vector<Tensor> outputs;
  for (int i = 0; i < chunks; ++i) {
    auto axis_dim = split_size;
    output_dims[canonical_axis] = split_size;
    Tensor output(output_dims, CPU);
    math::CopyMatrix<CPUContext>(
        input.itemsize(),
        before,
        axis_dim * after,
        static_cast<const char*>(input.raw_data()) + input_offset,
        input.dim32(canonical_axis) * after,
        output.raw_mutable_data(input.dtype()),
        axis_dim * after,
        context,
        input.dtype().copy());
    input_offset += axis_dim * after * input.itemsize();
    outputs.push_back(std::move(output));
  }
  return outputs;
}

std::vector<Tensor> unbind(const Tensor& input, int axis, CPUContext* context) {
  // 1 - Chunk the input tensor along the given axis into N chunks where
  // N is the dim(axis)
  auto chunks = chunk(input, input.sizes()[axis], axis, context);
  // 2 - Compute new dimensions
  std::vector<int64_t> newDims = input.sizes().vec();
  newDims.erase(newDims.begin() + axis);

  // 3 - Reshape chunks to drop the extra dimension
  for (int i = 0; i < chunks.size(); i++) {
    CAFFE_ENFORCE_EQ(
        chunks[i].sizes()[axis], 1, "Got an unexpected chunk size");
    chunks[i].Reshape(newDims);
  }
  return chunks;
}

Tensor
cat(const std::vector<Tensor>& tensorList, int axis, CPUContext* context) {
  // Adopted from C2's concat operator
  auto input_zero = copy_ctor(tensorList.at(0));
  vector<int64_t> outputDims(input_zero.sizes().vec());
  CAFFE_ENFORCE(outputDims.size() > 0);
  for (int i = 1; i < tensorList.size(); i++) {
    CAFFE_ENFORCE(input_zero.dtype() == tensorList.at(i).dtype());
    outputDims[axis] += tensorList.at(i).sizes()[axis];
  }
  auto output_channels = outputDims[axis];
  Tensor output(outputDims, CPU);
  int before = 1, after = 1;
  for (int i = 0; i < tensorList.at(0).dim(); ++i) {
    if (i == axis) {
      continue;
    }
    int dim = input_zero.dim32(i);
    if (i < axis) {
      before *= dim;
    } else {
      after *= dim;
    }
  }
  size_t output_offset = 0;
  for (const auto& input : tensorList) {
    auto axis_dim = input.dim32(axis);
    math::CopyMatrix<CPUContext>(
        input.itemsize(),
        before,
        axis_dim * after,
        input.raw_data(),
        axis_dim * after,
        static_cast<char*>(output.raw_mutable_data(input_zero.dtype())) +
            output_offset,
        output_channels * after,
        context,
        input_zero.dtype().copy());
    output_offset += axis_dim * after * input.itemsize();
  }

  return output;
}

Tensor
stack(const std::vector<Tensor>& tensorList, int axis, CPUContext* context) {
  // 1 - Compute new dimensions
  std::vector<int64_t> newDims(tensorList[0].sizes().vec());
  std::vector<Tensor> expandedTensorList;
  newDims.insert(newDims.begin() + axis, 1);
  for (int i = 0; i < tensorList.size(); i++) {
    expandedTensorList.emplace_back(tensorList[i].Clone());
    expandedTensorList.at(i).Reshape(newDims);
  }
  return cat(expandedTensorList, axis, context);
}

Tensor sigmoid(const Tensor& X) {
  Tensor Y(X.sizes(), CPU);
  auto N = X.numel();
  EigenVectorArrayMap<float>(Y.template mutable_data<float>(), N) = 1.0 /
      (1.0 +
       (-ConstEigenVectorArrayMap<float>(X.template data<float>(), N)).exp());
  return Y;
}

Tensor tanh(const Tensor& X, CPUContext* context) {
  Tensor Y(X.sizes(), CPU);
  math::Tanh<float, CPUContext>(
      X.numel(),
      X.template data<float>(),
      Y.template mutable_data<float>(),
      context);
  return Y;
}

Tensor add(const Tensor& X, const Tensor& Y, CPUContext* context) {
  Tensor Z(X.sizes().vec(), CPU);
  math::Add<float, CPUContext>(
      X.numel(),
      X.template data<float>(),
      Y.template data<float>(),
      Z.template mutable_data<float>(),
      context);
  return Z;
}

Tensor mul(const Tensor& X, const Tensor& Y, CPUContext* context) {
  Tensor Z(X.sizes().vec(), CPU);
  math::Mul<float, CPUContext>(
      X.numel(),
      X.template data<float>(),
      Y.template data<float>(),
      Z.template mutable_data<float>(),
      context);
  return Z;
}

Tensor transpose(const Tensor& X, int dim0, int dim1, CPUContext* context) {
  int ndim = X.dim();
  CAFFE_ENFORCE(ndim > dim0 && ndim > dim1, "Invalid transpose dimensions");
  std::vector<int> axes(ndim);
  std::iota(axes.begin(), axes.end(), 0);
  std::swap(axes[dim0], axes[dim1]);
  std::vector<int> Y_dims(ndim);
  std::vector<int> X_dims(X.sizes().cbegin(), X.sizes().cend());
  for (int i = 0; i < ndim; ++i) {
    Y_dims[i] = X_dims[axes[i]];
  }
  Tensor Y(Y_dims, CPU);
  math::Transpose<float, CPUContext>(
      ndim,
      X_dims.data(),
      axes.data(),
      X.template data<float>(),
      Y.template mutable_data<float>(),
      context);
  return Y;
}
} // namespace
} // namespace caffe2
caffe2/python/test/inference_lstm_op_test.py (new file, 72 lines)
@@ -0,0 +1,72 @@
#!/usr/bin/env python3
import inspect

import hypothesis.strategies as st
import numpy as np
import torch
from caffe2.python import core, workspace
from caffe2.python.test_util import TestCase
from hypothesis import given
from torch import nn


class TestC2LSTM(TestCase):
    @given(
        bsz=st.integers(1, 5),
        seq_lens=st.integers(1, 6),
        emb_lens=st.integers(5, 10),
        hidden_size=st.integers(3, 7),
        num_layers=st.integers(1, 4),
        has_biases=st.booleans(),
        is_bidirectional=st.booleans(),
        batch_first=st.booleans(),
    )
    def test_c2_lstm(
        self,
        bsz,
        seq_lens,
        emb_lens,
        hidden_size,
        num_layers,
        has_biases,
        is_bidirectional,
        batch_first,
    ):
        net = core.Net("test_net")
        num_directions = 2 if is_bidirectional else 1
        py_lstm = nn.LSTM(
            emb_lens,
            hidden_size,
            batch_first=batch_first,
            bidirectional=is_bidirectional,
            bias=has_biases,
            num_layers=num_layers,
        )

        hx = np.zeros((num_layers * num_directions, bsz, hidden_size), dtype=np.float32)

        if batch_first:
            inputs = np.random.randn(bsz, seq_lens, emb_lens).astype(np.float32)
        else:
            inputs = np.random.randn(seq_lens, bsz, emb_lens).astype(np.float32)

        py_results = py_lstm(torch.from_numpy(inputs))
        lstm_in = [
            torch.from_numpy(inputs),
            torch.from_numpy(hx),
            torch.from_numpy(hx),
        ] + [param.detach() for param in py_lstm._flat_weights]

        c2_results = torch.ops._caffe2.InferenceLSTM(
            lstm_in, num_layers, has_biases, batch_first, is_bidirectional
        )

        np.testing.assert_array_almost_equal(
            py_results[0].detach().numpy(), c2_results[0].detach().numpy()
        )
        np.testing.assert_array_almost_equal(
            py_results[1][0].detach().numpy(), c2_results[1].detach().numpy()
        )
        np.testing.assert_array_almost_equal(
            py_results[1][1].detach().numpy(), c2_results[2].detach().numpy()
        )