Optimize SoftmaxOp on CPU (#18635)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18635

Optimize SoftmaxOp on CPU

Reviewed By: houseroad

Differential Revision: D14689516

fbshipit-source-id: d2dcee2476d1a3a21f428e99bce9835f1d229d64
Author: Xiaomeng Yang
Date: 2019-04-10 18:45:57 -07:00
Committed by: Facebook Github Bot
Parent: 1abbee0f8e
Commit: 821b5f138a
10 changed files with 153 additions and 206 deletions

caffe2/operators/softmax_op.cc

@@ -1,49 +1,29 @@
 #include "caffe2/operators/softmax_op.h"
-#include "caffe2/operators/softmax_shared.h"
+#include "caffe2/operators/softmax_utils.h"

 namespace caffe2 {

 // Implementation for the CPU context.
 template <>
 bool SoftmaxOp<float, CPUContext>::RunOnDevice() {
-  auto& X = Input(0);
-  const int canonical_axis = X.canonical_axis_index(axis_);
+  const auto& X = Input(0);
+  const auto canonical_axis = X.canonical_axis_index(axis_);
   const int N = X.size_to_dim(canonical_axis);
   const int D = X.size_from_dim(canonical_axis);
   auto* Y = Output(0, X.sizes(), at::dtype<float>());
-  float* Ydata = Y->template mutable_data<float>();
-  // First, get scales
+  const float* X_data = X.data<float>();
+  float* Y_data = Y->mutable_data<float>();
+  if (N == 0) {
+    return true;
+  }
   if (!scale_.defined()) {
     scale_ = caffe2::empty({N}, at::dtype<float>().device(CPU));
   } else if (scale_.numel() != N) {
     scale_.Resize(N);
   }
-  if (!rowmax_.defined()) {
-    rowmax_ = caffe2::empty({N}, at::dtype<float>().device(CPU));
-  } else if (rowmax_.numel() != N) {
-    rowmax_.Resize(N);
-  }
-  if (!sum_multiplier_.defined()) {
-    sum_multiplier_ = caffe2::empty({D}, at::dtype<float>().device(CPU));
-    math::Set<float, CPUContext>(D, 1.f, sum_multiplier_.mutable_data<float>(), &context_);
-  } else if (sum_multiplier_.numel() != D) {
-    sum_multiplier_.Resize(D);
-    math::Set<float, CPUContext>(D, 1.f, sum_multiplier_.mutable_data<float>(), &context_);
-  }
-  SoftmaxCPU(
-      context_,
-      N,
-      D,
-      X.data<float>(),
-      Ydata,
-      scale_.mutable_data<float>(),
-      sum_multiplier_.data<float>(),
-      false,
-      rowmax_.mutable_data<float>());
+  softmax_utils::SoftmaxCPU<float>(
+      N, D, false, X_data, Y_data, scale_.mutable_data<float>(), &context_);
   return true;
 }
@@ -65,10 +45,12 @@ bool SoftmaxGradientOp<float, CPUContext>::RunOnDevice() {
   if (!sum_multiplier_.defined()) {
     sum_multiplier_ = caffe2::empty({D}, at::dtype<float>().device(CPU));
-    math::Set<float, CPUContext>(D, 1.f, sum_multiplier_.mutable_data<float>(), &context_);
+    math::Set<float, CPUContext>(
+        D, 1.f, sum_multiplier_.mutable_data<float>(), &context_);
   } else if (sum_multiplier_.numel() != D) {
     sum_multiplier_.Resize(D);
-    math::Set<float, CPUContext>(D, 1.f, sum_multiplier_.mutable_data<float>(), &context_);
+    math::Set<float, CPUContext>(
+        D, 1.f, sum_multiplier_.mutable_data<float>(), &context_);
   }

   auto* dX = Output(0, Y.sizes(), at::dtype<float>());
@@ -81,12 +63,21 @@ bool SoftmaxGradientOp<float, CPUContext>::RunOnDevice() {
   context_.CopySameDevice<float>(Y.numel(), dYdata, dXdata);
   float* scaledata = scale_.mutable_data<float>();
   for (int i = 0; i < N; ++i) {
-    math::Dot<float, CPUContext>(D, Ydata + i * D, dYdata + i * D,
-                                 scaledata + i, &context_);
+    math::Dot<float, CPUContext>(
+        D, Ydata + i * D, dYdata + i * D, scaledata + i, &context_);
   }
-  math::Gemm<float, CPUContext>(CblasNoTrans, CblasNoTrans, N, D, 1, -1,
-                                scaledata, sum_multiplier_.data<float>(), 1,
-                                dXdata, &context_);
+  math::Gemm<float, CPUContext>(
+      CblasNoTrans,
+      CblasNoTrans,
+      N,
+      D,
+      1,
+      -1,
+      scaledata,
+      sum_multiplier_.data<float>(),
+      1,
+      dXdata,
+      &context_);
   math::Mul<float, CPUContext>(Y.numel(), dXdata, Ydata, dXdata, &context_);
   return true;
 }
@@ -184,7 +175,8 @@ class GetSoftmaxGradient : public GradientMakerBase {
   using GradientMakerBase::GradientMakerBase;
   vector<OperatorDef> GetGradientDefs() override {
     return SingleGradientDef(
-        def_.type() + "Gradient", "",
+        def_.type() + "Gradient",
+        "",
         vector<string>{O(0), GO(0)},
         vector<string>{GI(0)});
   }
@@ -192,4 +184,4 @@ class GetSoftmaxGradient : public GradientMakerBase {
 REGISTER_GRADIENT(Softmax, GetSoftmaxGradient);
 REGISTER_GRADIENT(SoftmaxFp16, GetSoftmaxGradient);
 } // namespace caffe2

caffe2/operators/softmax_op.h

@@ -16,6 +16,7 @@ class SoftmaxOp final : public Operator<Context> {
       : Operator<Context>(std::forward<Args>(args)...),
         axis_(this->template GetSingleArgument<int>("axis", 1)) {}
   USE_OPERATOR_CONTEXT_FUNCTIONS;
+
   bool RunOnDevice() override;

  protected:

caffe2/operators/softmax_shared.cc

@@ -1,55 +0,0 @@
-#include "caffe2/core/context.h"
-#include "caffe2/core/operator.h"
-#include "caffe2/utils/math.h"
-
-namespace caffe2 {
-
-void SoftmaxCPU(
-    CPUContext& context,
-    const int N,
-    const int D,
-    const float* Xdata,
-    float* Ydata,
-    float* scale,
-    const float* sum_multiplier,
-    bool logarithmic,
-    float* rowmax) {
-  math::RowwiseMax<float, CPUContext>(N, D, Xdata, rowmax, &context);
-  // Put the intermediate result X - max(X) into Y
-  context.template CopyFromCPU<float>(N * D, Xdata, Ydata);
-  // Subtract the max (for numerical reasons)
-  math::Gemm<float, CPUContext>(
-      CblasNoTrans,
-      CblasNoTrans,
-      N,
-      D,
-      1,
-      -1,
-      rowmax,
-      sum_multiplier,
-      1,
-      Ydata,
-      &context);
-  // Exponentiation
-  math::Exp<float, CPUContext>(N * D, Ydata, Ydata, &context);
-  math::Gemv<float, CPUContext>(
-      CblasNoTrans, N, D, 1, Ydata, sum_multiplier, 0, scale, &context);
-  // Do division
-  // TODO(Yangqing): maybe implement it more beautifully?
-  if (!logarithmic) {
-    for (int i = 0; i < N; ++i) {
-      for (int j = 0; j < D; ++j) {
-        Ydata[i * D + j] /= scale[i];
-      }
-    }
-  } else {
-    for (int i = 0; i < N; ++i) {
-      for (int j = 0; j < D; ++j) {
-        Ydata[i * D + j] =
-            Xdata[i * D + j] - rowmax[i] - log(fmaxf(scale[i], 1e-20f));
-      }
-    }
-  }
-}
-
-} // namespace caffe2
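For reference (not part of the patch): the helper removed above and the softmax_utils::SoftmaxCPU helper added below compute the same per-row quantities; only the strategy changes, from BLAS broadcasting (Gemm/Gemv against a ones vector plus explicit loops) to Eigen array expressions. With m_i = max_j x_{ij}:

\[
\text{softmax: } y_{ij} = \frac{e^{x_{ij} - m_i}}{\sum_k e^{x_{ik} - m_i}},
\qquad
\text{log branch: } y_{ij} = x_{ij} - m_i - \log \sum_k e^{x_{ik} - m_i}
\]

(the removed code additionally clamps the sum at 1e-20 before taking the log).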

caffe2/operators/softmax_shared.h

@@ -1,21 +0,0 @@
-#ifndef CAFFE2_OPERATORS_SOFTMAX_SHARED_H_
-#define CAFFE2_OPERATORS_SOFTMAX_SHARED_H_
-
-#include "caffe2/core/context.h"
-#include "caffe2/core/operator.h"
-
-namespace caffe2 {
-
-void SoftmaxCPU(
-    CPUContext& context,
-    const int N,
-    const int D,
-    const float* Xdata,
-    float* Ydata,
-    float* scale,
-    const float* sum_multiplier,
-    bool logarithmic,
-    float* rowmax);
-
-} // namespace caffe2
-#endif // #define CAFFE2_OPERATORS_SOFTMAX_SHARED_H_

caffe2/operators/softmax_utils.cc

@@ -0,0 +1,38 @@
+#include "caffe2/operators/softmax_utils.h"
+
+#include "caffe2/core/context.h"
+#include "caffe2/utils/eigen_utils.h"
+#include "caffe2/utils/math.h"
+
+namespace caffe2 {
+namespace softmax_utils {
+
+#define CAFFE2_SPECIALIZED_SOFTMAX_CPU(T)                     \
+  template <>                                                 \
+  void SoftmaxCPU<T>(                                         \
+      const int N,                                            \
+      const int D,                                            \
+      const bool logarithmic,                                 \
+      const T* X,                                             \
+      T* Y,                                                   \
+      T* scratch,                                             \
+      CPUContext* context) {                                  \
+    ConstEigenArrayMap<T> X_arr(X, D, N);                     \
+    EigenArrayMap<T> Y_arr(Y, D, N);                          \
+    EigenVectorArrayMap<T> scratch_arr(scratch, N);           \
+    scratch_arr = X_arr.colwise().maxCoeff().transpose();     \
+    Y_arr = X_arr.rowwise() - scratch_arr.transpose();        \
+    math::Exp<T, CPUContext>(N * D, Y, Y, context);           \
+    if (logarithmic) {                                        \
+      scratch_arr += Y_arr.colwise().sum().log().transpose(); \
+      Y_arr = X_arr.rowwise() - scratch_arr.transpose();      \
+    } else {                                                     \
+      scratch_arr = Y_arr.colwise().sum().inverse().transpose(); \
+      Y_arr = Y_arr.rowwise() * scratch_arr.transpose();         \
+    }                                                         \
+  }
+
+CAFFE2_SPECIALIZED_SOFTMAX_CPU(float)
+#undef CAFFE2_SPECIALIZED_SOFTMAX_CPU
+
+} // namespace softmax_utils
+} // namespace caffe2
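A note on the Eigen mapping in the new helper: ConstEigenArrayMap<T>(X, D, N) views the row-major N x D buffer as a column-major D x N array, so each column corresponds to one row (one example) of X, and the colwise() reductions produce one value per example. Below is a minimal standalone sketch of the same row-wise softmax / log-softmax computation in plain C++, for reference only; it is not part of this patch and SoftmaxRef is a made-up name.

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

// Reference re-implementation: X and Y are N x D row-major buffers.
void SoftmaxRef(int N, int D, bool logarithmic, const float* X, float* Y) {
  for (int i = 0; i < N; ++i) {
    const float* x = X + i * D;
    float* y = Y + i * D;
    const float row_max = *std::max_element(x, x + D);  // subtract max for stability
    float sum = 0.0f;
    for (int j = 0; j < D; ++j) {
      y[j] = std::exp(x[j] - row_max);
      sum += y[j];
    }
    for (int j = 0; j < D; ++j) {
      y[j] = logarithmic ? x[j] - row_max - std::log(sum) : y[j] / sum;
    }
  }
}

int main() {
  const std::vector<float> X = {1.0f, 2.0f, 3.0f, 0.0f, 0.0f, 0.0f};
  std::vector<float> Y(X.size());
  SoftmaxRef(2, 3, /*logarithmic=*/false, X.data(), Y.data());
  for (float v : Y) {
    std::printf("%.4f ", v);  // each row of Y sums to 1
  }
  std::printf("\n");
  return 0;
}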

caffe2/operators/softmax_utils.h

@@ -0,0 +1,23 @@
+#ifndef CAFFE2_OPERATORS_SOFTMAX_UTILS_H_
+#define CAFFE2_OPERATORS_SOFTMAX_UTILS_H_
+
+#include "caffe2/core/context.h"
+#include "caffe2/core/operator.h"
+
+namespace caffe2 {
+namespace softmax_utils {
+
+template <typename T>
+void SoftmaxCPU(
+    int N,
+    int D,
+    bool logarithmic,
+    const T* X,
+    T* Y,
+    T* scratch,
+    CPUContext* context);
+
+} // namespace softmax_utils
+} // namespace caffe2
+
+#endif // CAFFE2_OPERATORS_SOFTMAX_UTILS_H_

caffe2/operators/softmax_with_loss_op.cc

@@ -1,5 +1,8 @@
-#include "softmax_with_loss_op.h"
-#include "softmax_shared.h"
+#include "caffe2/operators/softmax_with_loss_op.h"
+
+#include <vector>
+
+#include "caffe2/operators/softmax_utils.h"

 namespace caffe2 {
@@ -12,28 +15,28 @@ REGISTER_CPU_OPERATOR(
 OPERATOR_SCHEMA(SoftmaxWithLoss)
     .NumInputs(2, 3)
     .NumOutputs(2)
-    .TensorInferenceFunction(
-        [](const OperatorDef& def, const vector<TensorShape>& in) {
+    .TensorInferenceFunction([](const OperatorDef& def,
+                                const vector<TensorShape>& in) {
       ArgumentHelper helper(def);
       auto axis = helper.GetSingleArgument<int32_t>("axis", 1);
       vector<TensorShape> out(2);
       auto logits = in[0]; // Tensor with Shape [batch_size, num_classes]
       auto labels = in[1]; // Tensor with shape [batch_size, ]
       const auto canonical_axis =
           canonical_axis_index_(axis, logits.dims().size());
       const int batch_size =
           size_to_dim_(canonical_axis, GetDimsVector(logits));
       const int num_classes =
           size_from_dim_(canonical_axis, GetDimsVector(logits));
       out[0].set_data_type(logits.data_type());
       out[0].add_dims(batch_size);
       out[0].add_dims(num_classes);
       return out;
     })
     .SetDoc(R"DOC(
 Combined Softmax and Cross-Entropy loss operator. The operator first computes the softmax normalized values for each layer in the batch of the given input, then computes cross-entropy loss. This operator is numerically more stable than separate `Softmax` and `CrossEntropy` ops. The inputs are a 2-D tensor `logits` of size (batch_size x input_feature_dimensions), which represents the unscaled log probabilities, and a 1-dimensional integer `labels` tensor for ground truth. An optional third input blob (`weight_tensor`) can be used to weight the samples for the loss, which is useful if the training set is unbalanced. This operator outputs a `softmax` tensor which contains the probability for each label for each example (same shape is `logits` input), and a scalar `loss` value, which is the averaged cross-entropy loss between the softmax probabilities and the ground truth values. Use parameter `label_prob`=1 to enable inputting labels as a probability distribution.
@@ -132,10 +135,18 @@ avgloss: 10.667433
 </details>
 )DOC")
-    .Arg("label_prob","*(type: int; default: 0)* Setting to 1 enables inputting labels as probability distribution.")
-    .Arg("axis","*(type: int; default: 1)* Axis of the inputs when coerced to 2D.")
-    .Arg("scale","*(type: float)* Average loss output scaling factor (must be >= 0).")
-    .Arg("order","*(type: string; default: 'NCHW')* Order of blob dimensions (only 'NCHW' is supported currently).")
+    .Arg(
+        "label_prob",
+        "*(type: int; default: 0)* Setting to 1 enables inputting labels as probability distribution.")
+    .Arg(
+        "axis",
+        "*(type: int; default: 1)* Axis of the inputs when coerced to 2D.")
+    .Arg(
+        "scale",
+        "*(type: float)* Average loss output scaling factor (must be >= 0).")
+    .Arg(
+        "order",
+        "*(type: string; default: 'NCHW')* Order of blob dimensions (only 'NCHW' is supported currently).")
     .Input(0, "logits", "*(type: Tensor`<float>`)* Input tensor.")
     .Input(1, "labels", "*(type: Tensor`<float>`)* Ground truth label tensor.")
     .Input(
@@ -178,36 +189,20 @@ bool SoftmaxWithLossOp<float, CPUContext>::RunOnDevice() {
     }
   }

-  if (!sum_multiplier_.defined()) {
-    sum_multiplier_ = caffe2::empty({D}, at::dtype<float>().device(CPU));
-    math::Set<float, CPUContext>(D, 1.f, sum_multiplier_.mutable_data<float>(), &context_);
-  } else if (sum_multiplier_.numel() != D) {
-    sum_multiplier_.Resize(D);
-    math::Set<float, CPUContext>(D, 1.f, sum_multiplier_.mutable_data<float>(), &context_);
-  }
   if (!losses_.defined()) {
     losses_ = caffe2::empty({N}, at::dtype<float>().device(CPU));
   } else if (losses_.numel() != N) {
     losses_.Resize(N);
   }
-  if (!rowmax_.defined()) {
-    rowmax_ = caffe2::empty({N}, at::dtype<float>().device(CPU));
-  } else if (rowmax_.numel() != N) {
-    rowmax_.Resize(N);
-  }
-  SoftmaxCPU(
-      context_,
-      N,
-      D,
-      X.data<float>(),
-      Pdata,
-      losses_.mutable_data<float>(),
-      sum_multiplier_.data<float>(),
-      !label_prob_mode_,
-      rowmax_.mutable_data<float>());
+  softmax_utils::SoftmaxCPU<float>(
+      N,
+      D,
+      !label_prob_mode_,
+      X.data<float>(),
+      Pdata,
+      losses_.mutable_data<float>(),
+      &context_);

   // Then compute cross entropy
   float loss_sum = 0.0;
@@ -382,5 +377,5 @@ class GetSoftmaxWithLossGradient : public GradientMakerBase {
 };

 REGISTER_GRADIENT(SoftmaxWithLoss, GetSoftmaxWithLossGradient);
-}
+} // namespace
 } // namespace caffe2

caffe2/operators/spatial_softmax_with_loss_op.cc

@@ -1,5 +1,4 @@
-#include "spatial_softmax_with_loss_op.h"
-#include "softmax_shared.h"
+#include "caffe2/operators/spatial_softmax_with_loss_op.h"

 namespace caffe2 {

modules/detectron/group_spatial_softmax_op.cc

@@ -14,8 +14,9 @@
  * limitations under the License.
  */

-#include "group_spatial_softmax_op.h"
-#include "caffe2/operators/softmax_shared.h"
+#include "modules/detectron/group_spatial_softmax_op.h"
+
+#include "caffe2/operators/softmax_utils.h"

 namespace caffe2 {
@@ -59,18 +60,12 @@ See: https://arxiv.org/abs/1708.02002 for details.
 OPERATOR_SCHEMA(GroupSpatialSoftmaxGradient)
     .NumInputs(2)
     .NumOutputs(1)
-    .Input(
-        0,
-        "scores",
-        "See GroupSpatialSoftmax")
+    .Input(0, "scores", "See GroupSpatialSoftmax")
     .Input(
         1,
         "d_probabilities",
         "Gradient of forward output 0 (probabilities).")
-    .Output(
-        0,
-        "d_scores",
-        "Gradient of forward input 0 (scores).");
+    .Output(0, "d_scores", "Gradient of forward input 0 (scores).");

 class GetGroupSpatialSoftmaxGradient : public GradientMakerBase {
   using GradientMakerBase::GradientMakerBase;
@@ -84,4 +79,5 @@ class GetGroupSpatialSoftmaxGradient : public GradientMakerBase {
 };

 REGISTER_GRADIENT(GroupSpatialSoftmax, GetGroupSpatialSoftmaxGradient);
 } // namespace caffe2

modules/detectron/softmax_focal_loss_op.cc

@@ -14,8 +14,9 @@
  * limitations under the License.
  */

-#include "softmax_focal_loss_op.h"
-#include "caffe2/operators/softmax_shared.h"
+#include "modules/detectron/softmax_focal_loss_op.h"
+
+#include "caffe2/operators/softmax_utils.h"

 namespace caffe2 {
@@ -46,12 +47,8 @@ See: https://arxiv.org/abs/1708.02002 for details.
     .Arg(
         "scale",
         "(float) default 1.0; multiply the loss by this scale factor.")
-    .Arg(
-        "alpha",
-        "(float) default 0.25; Focal Loss's alpha hyper-parameter.")
-    .Arg(
-        "gamma",
-        "(float) default 1.0; Focal Loss's gamma hyper-parameter.")
+    .Arg("alpha", "(float) default 0.25; Focal Loss's alpha hyper-parameter.")
+    .Arg("gamma", "(float) default 1.0; Focal Loss's gamma hyper-parameter.")
     .Arg(
         "num_classes",
         "(int) default 81; number of classes in each softmax group.")
@@ -69,12 +66,8 @@ See: https://arxiv.org/abs/1708.02002 for details.
     .Input(
         2,
         "normalizer",
-        "Scalar; the loss is normalized by 1 / max(1, normalizer)."
-    )
-    .Output(
-        0,
-        "loss",
-        "Scalar loss.")
+        "Scalar; the loss is normalized by 1 / max(1, normalizer).")
+    .Output(0, "loss", "Scalar loss.")
     .Output(
         1,
         "probabilities",
@@ -85,30 +78,15 @@ See: https://arxiv.org/abs/1708.02002 for details.
 OPERATOR_SCHEMA(SoftmaxFocalLossGradient)
     .NumInputs(5)
     .NumOutputs(1)
-    .Input(
-        0,
-        "scores",
-        "See SoftmaxFocalLoss.")
-    .Input(
-        1,
-        "labels",
-        "See SoftmaxFocalLoss.")
-    .Input(
-        2,
-        "normalizer",
-        "See SoftmaxFocalLoss.")
+    .Input(0, "scores", "See SoftmaxFocalLoss.")
+    .Input(1, "labels", "See SoftmaxFocalLoss.")
+    .Input(2, "normalizer", "See SoftmaxFocalLoss.")
     .Input(
         3,
         "probabilities",
         "Output 1 from SoftmaxFocalLoss; See SoftmaxFocalLoss.")
-    .Input(
-        4,
-        "d_loss",
-        "Gradient of forward output 0 (loss)")
-    .Output(
-        0,
-        "d_scores",
-        "Gradient of forward input 0 (scores)");
+    .Input(4, "d_loss", "Gradient of forward output 0 (loss)")
+    .Output(0, "d_scores", "Gradient of forward input 0 (scores)");

 class GetSoftmaxFocalLossGradient : public GradientMakerBase {
   using GradientMakerBase::GradientMakerBase;
@@ -122,4 +100,5 @@ class GetSoftmaxFocalLossGradient : public GradientMakerBase {
 };

 REGISTER_GRADIENT(SoftmaxFocalLoss, GetSoftmaxFocalLossGradient);
 } // namespace caffe2