Files
pytorch/caffe2/operators/fully_connected_op.h
2016-11-15 00:00:46 -08:00

202 lines
5.4 KiB
C++

#ifndef CAFFE2_OPERATORS_FULLY_CONNECTED_OP_H_
#define CAFFE2_OPERATORS_FULLY_CONNECTED_OP_H_
#include "caffe2/core/context.h"
#include "caffe2/core/operator.h"
#include "caffe2/utils/math.h"
namespace caffe2 {
// This is Caffe's InnerProductOp, with a name that fits its purpose better.
// Computes Y = X * W^T + b, where X is viewed as an M x K matrix (flattened
// around `axis`), W has shape N x K, and b has length N.
template <typename T, class Context, class Engine = DefaultEngine>
class FullyConnectedOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  FullyConnectedOp(const OperatorDef& operator_def, Workspace* ws)
      : Operator<Context>(operator_def, ws),
        axis_(OperatorBase::GetSingleArgument<int32_t>("axis", 1)) {}
  ~FullyConnectedOp() {}

  bool RunOnDevice() override {
    const auto& X = Input(0); // data
    const auto& W = Input(1); // weights, N x K
    const auto& b = Input(2); // bias, length N
    auto* Y = Output(0);
    CAFFE_ENFORCE(W.ndim() == 2, W.ndim());
    CAFFE_ENFORCE(b.ndim() == 1, b.ndim());
    // Treat X as an M x K matrix: M is the product of the dimensions before
    // the canonical axis (batch size), K the product of the rest (features).
    const auto canon_axis = X.canonical_axis_index(axis_);
    const auto M = X.size_to_dim(canon_axis);
    const auto K = X.size_from_dim(canon_axis);
    const int N = W.dim32(0);
    // Builds a description of every shape involved, for error messages.
    auto dimErrMsg = [&]() {
      return MakeString(
          "Dimension mismatch: ", "X: ", X.dims(), ", W: ", W.dims(),
          ", b: ", b.dims(), ", axis: ", axis_, ", M: ", M, ", N: ", N,
          ", K: ", K);
    };
    // Sanity-check that the flattened views are mutually consistent.
    CAFFE_ENFORCE(M == X.size() / K, dimErrMsg());
    CAFFE_ENFORCE(K == W.size() / W.dim32(0), dimErrMsg());
    CAFFE_ENFORCE(N == b.dim32(0), dimErrMsg());
    CAFFE_ENFORCE(N == b.size(), dimErrMsg());
    // Output keeps X's leading dimensions and replaces the rest with N.
    Y_shape_cache_ = X.dims();
    // This is an invariant of canonical_axis_index, so DCHECK suffices.
    DCHECK_LE(canon_axis + 1, Y_shape_cache_.size());
    Y_shape_cache_.resize(canon_axis + 1);
    Y_shape_cache_[canon_axis] = N;
    Y->Resize(Y_shape_cache_);
    CAFFE_ENFORCE(M * N == Y->size(), dimErrMsg());
    // Y = X * W^T (W has shape N x K, hence the transpose flag).
    math::Gemm<T, Context, Engine>(
        CblasNoTrans, CblasTrans, M, N, K, 1, X.template data<T>(),
        W.template data<T>(), 0, Y->template mutable_data<T>(), &context_);
    // Broadcast-add the bias: Y += ones(M, 1) * b(1, N), with beta = 1 so
    // the GEMM accumulates into the product computed above.
    if (bias_multiplier_.size() != M) {
      // (Re)build the cached column of M ones used for the broadcast.
      bias_multiplier_.Resize(M);
      math::Set<T, Context>(
          M, static_cast<T>(1), bias_multiplier_.template mutable_data<T>(),
          &context_);
    }
    math::Gemm<T, Context, Engine>(
        CblasNoTrans, CblasNoTrans, M, N, 1, 1,
        bias_multiplier_.template data<T>(), b.template data<T>(), 1,
        Y->template mutable_data<T>(), &context_);
    return true;
  }

 protected:
  // Axis around which X is flattened into a 2-D matrix (default 1).
  size_t axis_{1};
  // A local vector to cache the output shape so we don't need to recreate
  // a vector object every time we run Run().
  vector<TIndex> Y_shape_cache_;
  // Cached column of ones used to broadcast the bias across the batch.
  Tensor<Context> bias_multiplier_;
};
// Gradient of FullyConnectedOp: given X, W and dY, produces dW and db, and
// optionally dX when a third output is requested.
template <typename T, class Context, class Engine = DefaultEngine>
class FullyConnectedGradientOp : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  FullyConnectedGradientOp(const OperatorDef& operator_def, Workspace* ws)
      : Operator<Context>(operator_def, ws),
        axis_(OperatorBase::GetSingleArgument<int32_t>("axis", 1)) {}
  ~FullyConnectedGradientOp() {}

  bool RunOnDevice() override {
    const auto& X = Input(0);
    const auto& W = Input(1);
    const auto& dY = Input(2);
    CAFFE_ENFORCE(W.ndim() == 2, W.ndim());
    // Flatten X to M x K around axis_, mirroring the forward operator.
    const auto canon_axis = X.canonical_axis_index(axis_);
    const int M = X.size_to_dim(canon_axis);
    const int K = X.size_from_dim(canon_axis);
    const int N = W.dim32(0);
    CAFFE_ENFORCE(M * K == X.size());
    CAFFE_ENFORCE(K * N == W.size());
    auto* dW = Output(0);
    auto* db = Output(1);
    dW->ResizeLike(W);
    db->Resize(N);
    // dW = dY^T * X, an N x K matrix matching W's shape.
    math::Gemm<T, Context, Engine>(
        CblasTrans, CblasNoTrans, N, K, M, 1, dY.template data<T>(),
        X.template data<T>(), 0, dW->template mutable_data<T>(), &context_);
    if (bias_multiplier_.size() != M) {
      // (Re)build the cached vector of M ones used to reduce over the batch.
      bias_multiplier_.Resize(M);
      math::Set<T, Context>(
          M, static_cast<T>(1), bias_multiplier_.template mutable_data<T>(),
          &context_);
    }
    // db = dY^T * ones(M), i.e. dY summed over the batch dimension.
    math::Gemv<T, Context>(
        CblasTrans, M, N, 1, dY.template data<T>(),
        bias_multiplier_.template data<T>(), 0,
        db->template mutable_data<T>(), &context_);
    // dX = dY * W, computed only when the caller wired up a third output.
    if (OutputSize() == 3) {
      auto* dX = Output(2);
      dX->ResizeLike(X);
      math::Gemm<T, Context, Engine>(
          CblasNoTrans, CblasNoTrans, M, K, N, 1, dY.template data<T>(),
          W.template data<T>(), 0, dX->template mutable_data<T>(),
          &context_);
    }
    return true;
  }

 protected:
  // Axis around which X is flattened into a 2-D matrix (default 1).
  size_t axis_{1};
  // Cached vector of ones used to reduce dY over the batch dimension.
  Tensor<Context> bias_multiplier_;
};
} // namespace caffe2
#endif // CAFFE2_OPERATORS_FULLY_CONNECTED_OP_H_