Optimize SoftmaxOp on CPU (#18635)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18635

Optimize SoftmaxOp on CPU

Reviewed By: houseroad

Differential Revision: D14689516

fbshipit-source-id: d2dcee2476d1a3a21f428e99bce9835f1d229d64
This commit is contained in:
Xiaomeng Yang
2019-04-10 18:45:57 -07:00
committed by Facebook Github Bot
parent 1abbee0f8e
commit 821b5f138a
10 changed files with 153 additions and 206 deletions

View File

@ -1,49 +1,29 @@
#include "caffe2/operators/softmax_op.h"
#include "caffe2/operators/softmax_shared.h"
#include "caffe2/operators/softmax_utils.h"
namespace caffe2 {
// Implementation for the CPU context.
template <>
bool SoftmaxOp<float, CPUContext>::RunOnDevice() {
auto& X = Input(0);
const auto canonical_axis = X.canonical_axis_index(axis_);
const auto& X = Input(0);
const int canonical_axis = X.canonical_axis_index(axis_);
const int N = X.size_to_dim(canonical_axis);
const int D = X.size_from_dim(canonical_axis);
auto* Y = Output(0, X.sizes(), at::dtype<float>());
float* Ydata = Y->template mutable_data<float>();
// First, get scales
const float* X_data = X.data<float>();
float* Y_data = Y->mutable_data<float>();
if (N == 0) {
return true;
}
if (!scale_.defined()) {
scale_ = caffe2::empty({N}, at::dtype<float>().device(CPU));
} else if (scale_.numel() != N) {
scale_.Resize(N);
}
if (!rowmax_.defined()) {
rowmax_ = caffe2::empty({N}, at::dtype<float>().device(CPU));
} else if (rowmax_.numel() != N) {
rowmax_.Resize(N);
}
if (!sum_multiplier_.defined()) {
sum_multiplier_ = caffe2::empty({D}, at::dtype<float>().device(CPU));
math::Set<float, CPUContext>(D, 1.f, sum_multiplier_.mutable_data<float>(), &context_);
} else if (sum_multiplier_.numel() != D) {
sum_multiplier_.Resize(D);
math::Set<float, CPUContext>(D, 1.f, sum_multiplier_.mutable_data<float>(), &context_);
}
SoftmaxCPU(
context_,
N,
D,
X.data<float>(),
Ydata,
scale_.mutable_data<float>(),
sum_multiplier_.data<float>(),
false,
rowmax_.mutable_data<float>());
softmax_utils::SoftmaxCPU<float>(
N, D, false, X_data, Y_data, scale_.mutable_data<float>(), &context_);
return true;
}
@ -65,10 +45,12 @@ bool SoftmaxGradientOp<float, CPUContext>::RunOnDevice() {
if (!sum_multiplier_.defined()) {
sum_multiplier_ = caffe2::empty({D}, at::dtype<float>().device(CPU));
math::Set<float, CPUContext>(D, 1.f, sum_multiplier_.mutable_data<float>(), &context_);
math::Set<float, CPUContext>(
D, 1.f, sum_multiplier_.mutable_data<float>(), &context_);
} else if (sum_multiplier_.numel() != D) {
sum_multiplier_.Resize(D);
math::Set<float, CPUContext>(D, 1.f, sum_multiplier_.mutable_data<float>(), &context_);
math::Set<float, CPUContext>(
D, 1.f, sum_multiplier_.mutable_data<float>(), &context_);
}
auto* dX = Output(0, Y.sizes(), at::dtype<float>());
@ -81,12 +63,21 @@ bool SoftmaxGradientOp<float, CPUContext>::RunOnDevice() {
context_.CopySameDevice<float>(Y.numel(), dYdata, dXdata);
float* scaledata = scale_.mutable_data<float>();
for (int i = 0; i < N; ++i) {
math::Dot<float, CPUContext>(D, Ydata + i * D, dYdata + i * D,
scaledata + i, &context_);
math::Dot<float, CPUContext>(
D, Ydata + i * D, dYdata + i * D, scaledata + i, &context_);
}
math::Gemm<float, CPUContext>(CblasNoTrans, CblasNoTrans, N, D, 1, -1,
scaledata, sum_multiplier_.data<float>(), 1,
dXdata, &context_);
math::Gemm<float, CPUContext>(
CblasNoTrans,
CblasNoTrans,
N,
D,
1,
-1,
scaledata,
sum_multiplier_.data<float>(),
1,
dXdata,
&context_);
math::Mul<float, CPUContext>(Y.numel(), dXdata, Ydata, dXdata, &context_);
return true;
}
@ -184,7 +175,8 @@ class GetSoftmaxGradient : public GradientMakerBase {
using GradientMakerBase::GradientMakerBase;
vector<OperatorDef> GetGradientDefs() override {
return SingleGradientDef(
def_.type() + "Gradient", "",
def_.type() + "Gradient",
"",
vector<string>{O(0), GO(0)},
vector<string>{GI(0)});
}
@ -192,4 +184,4 @@ class GetSoftmaxGradient : public GradientMakerBase {
REGISTER_GRADIENT(Softmax, GetSoftmaxGradient);
REGISTER_GRADIENT(SoftmaxFp16, GetSoftmaxGradient);
} // namespace caffe2
} // namespace caffe2

View File

@ -16,6 +16,7 @@ class SoftmaxOp final : public Operator<Context> {
: Operator<Context>(std::forward<Args>(args)...),
axis_(this->template GetSingleArgument<int>("axis", 1)) {}
USE_OPERATOR_CONTEXT_FUNCTIONS;
bool RunOnDevice() override;
protected:

View File

@ -1,55 +0,0 @@
#include "caffe2/core/context.h"
#include "caffe2/core/operator.h"
#include "caffe2/utils/math.h"
namespace caffe2 {
void SoftmaxCPU(
CPUContext& context,
const int N,
const int D,
const float* Xdata,
float* Ydata,
float* scale,
const float* sum_multiplier,
bool logarithmic,
float* rowmax) {
math::RowwiseMax<float, CPUContext>(N, D, Xdata, rowmax, &context);
// Put the intermediate result X - max(X) into Y
context.template CopyFromCPU<float>(N * D, Xdata, Ydata);
// Subtract the max (for numerical reasons)
math::Gemm<float, CPUContext>(
CblasNoTrans,
CblasNoTrans,
N,
D,
1,
-1,
rowmax,
sum_multiplier,
1,
Ydata,
&context);
// Exponentiation
math::Exp<float, CPUContext>(N * D, Ydata, Ydata, &context);
math::Gemv<float, CPUContext>(
CblasNoTrans, N, D, 1, Ydata, sum_multiplier, 0, scale, &context);
// Do division
// TODO(Yangqing): maybe implement it more beautifully?
if (!logarithmic) {
for (int i = 0; i < N; ++i) {
for (int j = 0; j < D; ++j) {
Ydata[i * D + j] /= scale[i];
}
}
} else {
for (int i = 0; i < N; ++i) {
for (int j = 0; j < D; ++j) {
Ydata[i * D + j] =
Xdata[i * D + j] - rowmax[i] - log(fmaxf(scale[i], 1e-20f));
}
}
}
}
} // namespace caffe2

View File

@ -1,21 +0,0 @@
#ifndef CAFFE2_OPERATORS_SOFTMAX_SHARED_H_
#define CAFFE2_OPERATORS_SOFTMAX_SHARED_H_
#include "caffe2/core/context.h"
#include "caffe2/core/operator.h"
namespace caffe2 {
void SoftmaxCPU(
CPUContext& context,
const int N,
const int D,
const float* Xdata,
float* Ydata,
float* scale,
const float* sum_multiplier,
bool logarithmic,
float* rowmax);
} // namespace caffe2
#endif // #define CAFFE2_OPERATORS_SOFTMAX_SHARED_H_

View File

@ -0,0 +1,38 @@
#include "caffe2/operators/softmax_utils.h"
#include "caffe2/core/context.h"
#include "caffe2/utils/eigen_utils.h"
#include "caffe2/utils/math.h"
namespace caffe2 {
namespace softmax_utils {
#define CAFFE2_SPECIALIZED_SOFTMAX_CPU(T) \
template <> \
void SoftmaxCPU<T>( \
const int N, \
const int D, \
const bool logarithmic, \
const T* X, \
T* Y, \
T* scratch, \
CPUContext* context) { \
ConstEigenArrayMap<T> X_arr(X, D, N); \
EigenArrayMap<T> Y_arr(Y, D, N); \
EigenVectorArrayMap<T> scratch_arr(scratch, N); \
scratch_arr = X_arr.colwise().maxCoeff().transpose(); \
Y_arr = X_arr.rowwise() - scratch_arr.transpose(); \
math::Exp<T, CPUContext>(N * D, Y, Y, context); \
if (logarithmic) { \
scratch_arr += Y_arr.colwise().sum().log().transpose(); \
Y_arr = X_arr.rowwise() - scratch_arr.transpose(); \
} else { \
scratch_arr = Y_arr.colwise().sum().inverse().transpose(); \
Y_arr = Y_arr.rowwise() * scratch_arr.transpose(); \
} \
}
CAFFE2_SPECIALIZED_SOFTMAX_CPU(float)
#undef CAFFE2_SPECIALIZED_SOFTMAX_CPU
} // namespace softmax_utils
} // namespace caffe2

View File

@ -0,0 +1,23 @@
#ifndef CAFFE2_OPERATORS_SOFTMAX_UTILS_H_
#define CAFFE2_OPERATORS_SOFTMAX_UTILS_H_
#include "caffe2/core/context.h"
#include "caffe2/core/operator.h"
namespace caffe2 {
namespace softmax_utils {
template <typename T>
void SoftmaxCPU(
int N,
int D,
bool logarithmic,
const T* X,
T* Y,
T* scratch,
CPUContext* context);
} // namespace softmax_utils
} // namespace caffe2
#endif // CAFFE2_OPERATORS_SOFTMAX_UTILS_H_

View File

@ -1,5 +1,8 @@
#include "softmax_with_loss_op.h"
#include "softmax_shared.h"
#include "caffe2/operators/softmax_with_loss_op.h"
#include <vector>
#include "caffe2/operators/softmax_utils.h"
namespace caffe2 {
@ -12,28 +15,28 @@ REGISTER_CPU_OPERATOR(
OPERATOR_SCHEMA(SoftmaxWithLoss)
.NumInputs(2, 3)
.NumOutputs(2)
.TensorInferenceFunction(
[](const OperatorDef& def, const vector<TensorShape>& in) {
ArgumentHelper helper(def);
auto axis = helper.GetSingleArgument<int32_t>("axis", 1);
.TensorInferenceFunction([](const OperatorDef& def,
const vector<TensorShape>& in) {
ArgumentHelper helper(def);
auto axis = helper.GetSingleArgument<int32_t>("axis", 1);
vector<TensorShape> out(2);
vector<TensorShape> out(2);
auto logits = in[0]; // Tensor with Shape [batch_size, num_classes]
auto labels = in[1]; // Tensor with shape [batch_size, ]
const auto canonical_axis =
canonical_axis_index_(axis, logits.dims().size());
const int batch_size =
size_to_dim_(canonical_axis, GetDimsVector(logits));
const int num_classes =
size_from_dim_(canonical_axis, GetDimsVector(logits));
auto logits = in[0]; // Tensor with Shape [batch_size, num_classes]
auto labels = in[1]; // Tensor with shape [batch_size, ]
const auto canonical_axis =
canonical_axis_index_(axis, logits.dims().size());
const int batch_size =
size_to_dim_(canonical_axis, GetDimsVector(logits));
const int num_classes =
size_from_dim_(canonical_axis, GetDimsVector(logits));
out[0].set_data_type(logits.data_type());
out[0].add_dims(batch_size);
out[0].add_dims(num_classes);
out[0].set_data_type(logits.data_type());
out[0].add_dims(batch_size);
out[0].add_dims(num_classes);
return out;
})
return out;
})
.SetDoc(R"DOC(
Combined Softmax and Cross-Entropy loss operator. The operator first computes the softmax normalized values for each layer in the batch of the given input, then computes cross-entropy loss. This operator is numerically more stable than separate `Softmax` and `CrossEntropy` ops. The inputs are a 2-D tensor `logits` of size (batch_size x input_feature_dimensions), which represents the unscaled log probabilities, and a 1-dimensional integer `labels` tensor for ground truth. An optional third input blob (`weight_tensor`) can be used to weight the samples for the loss, which is useful if the training set is unbalanced. This operator outputs a `softmax` tensor which contains the probability for each label for each example (same shape is `logits` input), and a scalar `loss` value, which is the averaged cross-entropy loss between the softmax probabilities and the ground truth values. Use parameter `label_prob`=1 to enable inputting labels as a probability distribution.
@ -132,10 +135,18 @@ avgloss: 10.667433
</details>
)DOC")
.Arg("label_prob","*(type: int; default: 0)* Setting to 1 enables inputting labels as probability distribution.")
.Arg("axis","*(type: int; default: 1)* Axis of the inputs when coerced to 2D.")
.Arg("scale","*(type: float)* Average loss output scaling factor (must be >= 0).")
.Arg("order","*(type: string; default: 'NCHW')* Order of blob dimensions (only 'NCHW' is supported currently).")
.Arg(
"label_prob",
"*(type: int; default: 0)* Setting to 1 enables inputting labels as probability distribution.")
.Arg(
"axis",
"*(type: int; default: 1)* Axis of the inputs when coerced to 2D.")
.Arg(
"scale",
"*(type: float)* Average loss output scaling factor (must be >= 0).")
.Arg(
"order",
"*(type: string; default: 'NCHW')* Order of blob dimensions (only 'NCHW' is supported currently).")
.Input(0, "logits", "*(type: Tensor`<float>`)* Input tensor.")
.Input(1, "labels", "*(type: Tensor`<float>`)* Ground truth label tensor.")
.Input(
@ -178,36 +189,20 @@ bool SoftmaxWithLossOp<float, CPUContext>::RunOnDevice() {
}
}
if (!sum_multiplier_.defined()) {
sum_multiplier_ = caffe2::empty({D}, at::dtype<float>().device(CPU));
math::Set<float, CPUContext>(D, 1.f, sum_multiplier_.mutable_data<float>(), &context_);
} else if (sum_multiplier_.numel() != D) {
sum_multiplier_.Resize(D);
math::Set<float, CPUContext>(D, 1.f, sum_multiplier_.mutable_data<float>(), &context_);
}
if (!losses_.defined()) {
losses_ = caffe2::empty({N}, at::dtype<float>().device(CPU));
} else if (losses_.numel() != N) {
losses_.Resize(N);
}
if (!rowmax_.defined()) {
rowmax_ = caffe2::empty({N}, at::dtype<float>().device(CPU));
} else if (rowmax_.numel() != N) {
rowmax_.Resize(N);
}
SoftmaxCPU(
context_,
softmax_utils::SoftmaxCPU<float>(
N,
D,
!label_prob_mode_,
X.data<float>(),
Pdata,
losses_.mutable_data<float>(),
sum_multiplier_.data<float>(),
!label_prob_mode_,
rowmax_.mutable_data<float>());
&context_);
// Then compute cross entropy
float loss_sum = 0.0;
@ -382,5 +377,5 @@ class GetSoftmaxWithLossGradient : public GradientMakerBase {
};
REGISTER_GRADIENT(SoftmaxWithLoss, GetSoftmaxWithLossGradient);
}
} // namespace
} // namespace caffe2

View File

@ -1,5 +1,4 @@
#include "spatial_softmax_with_loss_op.h"
#include "softmax_shared.h"
#include "caffe2/operators/spatial_softmax_with_loss_op.h"
namespace caffe2 {

View File

@ -14,8 +14,9 @@
* limitations under the License.
*/
#include "group_spatial_softmax_op.h"
#include "caffe2/operators/softmax_shared.h"
#include "modules/detectron/group_spatial_softmax_op.h"
#include "caffe2/operators/softmax_utils.h"
namespace caffe2 {
@ -59,18 +60,12 @@ See: https://arxiv.org/abs/1708.02002 for details.
OPERATOR_SCHEMA(GroupSpatialSoftmaxGradient)
.NumInputs(2)
.NumOutputs(1)
.Input(
0,
"scores",
"See GroupSpatialSoftmax")
.Input(0, "scores", "See GroupSpatialSoftmax")
.Input(
1,
"d_probabilities",
"Gradient of forward output 0 (probabilities).")
.Output(
0,
"d_scores",
"Gradient of forward input 0 (scores).");
.Output(0, "d_scores", "Gradient of forward input 0 (scores).");
class GetGroupSpatialSoftmaxGradient : public GradientMakerBase {
using GradientMakerBase::GradientMakerBase;
@ -84,4 +79,5 @@ class GetGroupSpatialSoftmaxGradient : public GradientMakerBase {
};
REGISTER_GRADIENT(GroupSpatialSoftmax, GetGroupSpatialSoftmaxGradient);
} // namespace caffe2

View File

@ -14,8 +14,9 @@
* limitations under the License.
*/
#include "softmax_focal_loss_op.h"
#include "caffe2/operators/softmax_shared.h"
#include "modules/detectron/softmax_focal_loss_op.h"
#include "caffe2/operators/softmax_utils.h"
namespace caffe2 {
@ -46,12 +47,8 @@ See: https://arxiv.org/abs/1708.02002 for details.
.Arg(
"scale",
"(float) default 1.0; multiply the loss by this scale factor.")
.Arg(
"alpha",
"(float) default 0.25; Focal Loss's alpha hyper-parameter.")
.Arg(
"gamma",
"(float) default 1.0; Focal Loss's gamma hyper-parameter.")
.Arg("alpha", "(float) default 0.25; Focal Loss's alpha hyper-parameter.")
.Arg("gamma", "(float) default 1.0; Focal Loss's gamma hyper-parameter.")
.Arg(
"num_classes",
"(int) default 81; number of classes in each softmax group.")
@ -69,12 +66,8 @@ See: https://arxiv.org/abs/1708.02002 for details.
.Input(
2,
"normalizer",
"Scalar; the loss is normalized by 1 / max(1, normalizer)."
)
.Output(
0,
"loss",
"Scalar loss.")
"Scalar; the loss is normalized by 1 / max(1, normalizer).")
.Output(0, "loss", "Scalar loss.")
.Output(
1,
"probabilities",
@ -85,30 +78,15 @@ See: https://arxiv.org/abs/1708.02002 for details.
OPERATOR_SCHEMA(SoftmaxFocalLossGradient)
.NumInputs(5)
.NumOutputs(1)
.Input(
0,
"scores",
"See SoftmaxFocalLoss.")
.Input(
1,
"labels",
"See SoftmaxFocalLoss.")
.Input(
2,
"normalizer",
"See SoftmaxFocalLoss.")
.Input(0, "scores", "See SoftmaxFocalLoss.")
.Input(1, "labels", "See SoftmaxFocalLoss.")
.Input(2, "normalizer", "See SoftmaxFocalLoss.")
.Input(
3,
"probabilities",
"Output 1 from SoftmaxFocalLoss; See SoftmaxFocalLoss.")
.Input(
4,
"d_loss",
"Gradient of forward output 0 (loss)")
.Output(
0,
"d_scores",
"Gradient of forward input 0 (scores)");
.Input(4, "d_loss", "Gradient of forward output 0 (loss)")
.Output(0, "d_scores", "Gradient of forward input 0 (scores)");
class GetSoftmaxFocalLossGradient : public GradientMakerBase {
using GradientMakerBase::GradientMakerBase;
@ -122,4 +100,5 @@ class GetSoftmaxFocalLossGradient : public GradientMakerBase {
};
REGISTER_GRADIENT(SoftmaxFocalLoss, GetSoftmaxFocalLossGradient);
} // namespace caffe2