// Mirror of https://github.com/pytorch/pytorch.git
// (synced 2025-10-20 12:54:11 +08:00; 202 lines, 5.4 KiB, C++)
#ifndef CAFFE2_OPERATORS_FULLY_CONNECTED_OP_H_
#define CAFFE2_OPERATORS_FULLY_CONNECTED_OP_H_

#include "caffe2/core/context.h"
#include "caffe2/core/operator.h"
#include "caffe2/utils/math.h"

namespace caffe2 {
// This is Caffe's InnerProductOp, with a name that fits its purpose better.
|
|
template <typename T, class Context, class Engine = DefaultEngine>
|
|
class FullyConnectedOp final : public Operator<Context> {
|
|
public:
|
|
USE_OPERATOR_CONTEXT_FUNCTIONS;
|
|
FullyConnectedOp(const OperatorDef& operator_def, Workspace* ws)
|
|
: Operator<Context>(operator_def, ws),
|
|
axis_(OperatorBase::GetSingleArgument<int32_t>("axis", 1)) {}
|
|
~FullyConnectedOp() {}
|
|
|
|
bool RunOnDevice() override {
|
|
const auto& X = Input(0);
|
|
const auto& W = Input(1);
|
|
const auto& b = Input(2);
|
|
auto* Y = Output(0);
|
|
CAFFE_ENFORCE(W.ndim() == 2, W.ndim());
|
|
CAFFE_ENFORCE(b.ndim() == 1, b.ndim());
|
|
// batch size
|
|
const auto canonical_axis = X.canonical_axis_index(axis_);
|
|
const auto M = X.size_to_dim(canonical_axis);
|
|
const auto K = X.size_from_dim(canonical_axis);
|
|
const int N = W.dim32(0);
|
|
|
|
auto dimErrorString = [&]() {
|
|
return MakeString(
|
|
"Dimension mismatch: ",
|
|
"X: ",
|
|
X.dims(),
|
|
", W: ",
|
|
W.dims(),
|
|
", b: ",
|
|
b.dims(),
|
|
", axis: ",
|
|
axis_,
|
|
", M: ",
|
|
M,
|
|
", N: ",
|
|
N,
|
|
", K: ",
|
|
K);
|
|
};
|
|
|
|
// Error checking
|
|
CAFFE_ENFORCE(M == X.size() / K, dimErrorString());
|
|
CAFFE_ENFORCE(K == W.size() / W.dim32(0), dimErrorString());
|
|
CAFFE_ENFORCE(N == b.dim32(0), dimErrorString());
|
|
CAFFE_ENFORCE(N == b.size(), dimErrorString());
|
|
|
|
Y_shape_cache_ = X.dims();
|
|
// This is an invariant of canonical_axis, so we can DCHECK.
|
|
DCHECK_LE(canonical_axis + 1, Y_shape_cache_.size());
|
|
Y_shape_cache_.resize(canonical_axis + 1);
|
|
Y_shape_cache_[canonical_axis] = N;
|
|
Y->Resize(Y_shape_cache_);
|
|
CAFFE_ENFORCE(M * N == Y->size(), dimErrorString());
|
|
|
|
// W * x
|
|
math::Gemm<T, Context, Engine>(
|
|
CblasNoTrans,
|
|
CblasTrans,
|
|
M,
|
|
N,
|
|
K,
|
|
1,
|
|
X.template data<T>(),
|
|
W.template data<T>(),
|
|
0,
|
|
Y->template mutable_data<T>(),
|
|
&context_);
|
|
// Add bias term
|
|
if (bias_multiplier_.size() != M) {
|
|
// If the helper bias multiplier is not M, reshape and fill it with one.
|
|
bias_multiplier_.Resize(M);
|
|
math::Set<T, Context>(
|
|
M,
|
|
static_cast<T>(1),
|
|
bias_multiplier_.template mutable_data<T>(),
|
|
&context_);
|
|
}
|
|
math::Gemm<T, Context, Engine>(
|
|
CblasNoTrans,
|
|
CblasNoTrans,
|
|
M,
|
|
N,
|
|
1,
|
|
1,
|
|
bias_multiplier_.template data<T>(),
|
|
b.template data<T>(),
|
|
1,
|
|
Y->template mutable_data<T>(),
|
|
&context_);
|
|
return true;
|
|
}
|
|
|
|
protected:
|
|
size_t axis_{1};
|
|
// A local vector to cache the output shape so we don't need to recreate
|
|
// a vector object every time we run Run().
|
|
vector<TIndex> Y_shape_cache_;
|
|
Tensor<Context> bias_multiplier_;
|
|
};
template <typename T, class Context, class Engine = DefaultEngine>
|
|
class FullyConnectedGradientOp : public Operator<Context> {
|
|
public:
|
|
USE_OPERATOR_CONTEXT_FUNCTIONS;
|
|
FullyConnectedGradientOp(const OperatorDef& operator_def, Workspace* ws)
|
|
: Operator<Context>(operator_def, ws),
|
|
axis_(OperatorBase::GetSingleArgument<int32_t>("axis", 1)) {}
|
|
~FullyConnectedGradientOp() {}
|
|
|
|
bool RunOnDevice() override {
|
|
const auto& X = Input(0);
|
|
const auto& W = Input(1);
|
|
const auto& dY = Input(2);
|
|
CAFFE_ENFORCE(W.ndim() == 2, W.ndim());
|
|
// batch size
|
|
const auto canonical_axis = X.canonical_axis_index(axis_);
|
|
const int M = X.size_to_dim(canonical_axis);
|
|
const int K = X.size_from_dim(canonical_axis);
|
|
const int N = W.dim32(0);
|
|
CAFFE_ENFORCE(M * K == X.size());
|
|
CAFFE_ENFORCE(K * N == W.size());
|
|
|
|
auto* dW = Output(0);
|
|
auto* db = Output(1);
|
|
dW->ResizeLike(W);
|
|
db->Resize(N);
|
|
|
|
// Compute dW
|
|
math::Gemm<T, Context, Engine>(
|
|
CblasTrans,
|
|
CblasNoTrans,
|
|
N,
|
|
K,
|
|
M,
|
|
1,
|
|
dY.template data<T>(),
|
|
X.template data<T>(),
|
|
0,
|
|
dW->template mutable_data<T>(),
|
|
&context_);
|
|
if (bias_multiplier_.size() != M) {
|
|
// If the helper bias multiplier is not M, reshape and fill it
|
|
// with one.
|
|
bias_multiplier_.Resize(M);
|
|
math::Set<T, Context>(
|
|
M,
|
|
static_cast<T>(1),
|
|
bias_multiplier_.template mutable_data<T>(),
|
|
&context_);
|
|
}
|
|
// Compute dB
|
|
math::Gemv<T, Context>(
|
|
CblasTrans,
|
|
M,
|
|
N,
|
|
1,
|
|
dY.template data<T>(),
|
|
bias_multiplier_.template data<T>(),
|
|
0,
|
|
db->template mutable_data<T>(),
|
|
&context_);
|
|
|
|
// Compute dX
|
|
if (OutputSize() == 3) {
|
|
auto* dX = Output(2);
|
|
dX->ResizeLike(X);
|
|
math::Gemm<T, Context, Engine>(
|
|
CblasNoTrans,
|
|
CblasNoTrans,
|
|
M,
|
|
K,
|
|
N,
|
|
1,
|
|
dY.template data<T>(),
|
|
W.template data<T>(),
|
|
0,
|
|
dX->template mutable_data<T>(),
|
|
&context_);
|
|
}
|
|
return true;
|
|
}
|
|
|
|
protected:
|
|
size_t axis_{1};
|
|
Tensor<Context> bias_multiplier_;
|
|
};
} // namespace caffe2

#endif // CAFFE2_OPERATORS_FULLY_CONNECTED_OP_H_