#include "caffe2/operators/jsd_op.h"
|
|
|
|
namespace caffe2 {
|
|
|
|
namespace {
|
|
|
|
static constexpr float kLOG_THRESHOLD() {
|
|
return 1e-20;
|
|
}
|
|
|
|

inline float logit(float p) {
  // Computes log(p / (1 - p)). To avoid numerical issues as p approaches
  // 0 or 1, p is first clamped to [kLOG_THRESHOLD, 1 - kLOG_THRESHOLD].
  float x = std::min(std::max(p, kLOG_THRESHOLD()), 1 - kLOG_THRESHOLD());
  // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
  return -log(1. / x - 1.);
}

inline float entropy(float p) {
  // Entropy of a Bernoulli(p): H(p) = -p * log(p) - (1 - p) * log(1 - p).
  // H(p) -> 0 as p -> 0 or 1, so return 0 outside the threshold.
  if (p < kLOG_THRESHOLD() || 1 - p < kLOG_THRESHOLD()) {
    return 0.;
  } else {
    float q = 1 - p;
    return -p * log(p) - q * log(q);
  }
}
} // namespace
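
// Forward pass: per element, the Jensen-Shannon divergence of two Bernoulli
// distributions with parameters x (prediction) and t (target) is
//   JSD(x, t) = H((x + t) / 2) - (H(x) + H(t)) / 2,
// where H is the Bernoulli entropy above; this is exactly what the loop
// below computes.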

template <>
bool BernoulliJSDOp<float, CPUContext>::RunOnDevice() {
  auto& X = Input(0); // predicted probabilities
  auto& T = Input(1); // target probabilities
  int N = X.numel();
  CAFFE_ENFORCE_EQ(T.numel(), N);
  auto* L = Output(0, X.sizes(), at::dtype<float>()); // JSD loss output
  auto* x_data = X.data<float>();
  auto* t_data = T.data<float>();
  auto* l_data = L->template mutable_data<float>();
  for (int i = 0; i < N; i++) {
    auto p_mdl = x_data[i];
    auto p_emp = t_data[i];
    auto p_avg = (p_mdl + p_emp) / 2.;
    auto jsd = entropy(p_avg) - (entropy(p_mdl) + entropy(p_emp)) / 2.;
    l_data[i] = jsd;
  }
  return true;
}
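
// Backward pass: with m = (x + t) / 2 and dH(p)/dp = log((1 - p) / p)
// = -logit(p), the derivative of JSD(x, t) with respect to x is
//   d/dx [H(m) - (H(x) + H(t)) / 2] = (logit(x) - logit(m)) / 2,
// which the loop below scales by the incoming gradient go.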

template <>
bool BernoulliJSDGradientOp<float, CPUContext>::RunOnDevice() {
  auto& go = Input(0);
  auto& X = Input(1);
  auto& T = Input(2);

  int N = X.numel();
  auto* gi = Output(0, X.sizes(), at::dtype<float>());
  auto* go_data = go.data<float>();
  auto* x_data = X.data<float>();
  auto* t_data = T.data<float>();
  auto* gi_data = gi->template mutable_data<float>();
  for (int i = 0; i < N; i++) {
    auto p_mdl = x_data[i];
    auto p_emp = t_data[i];
    auto p_avg = (p_mdl + p_emp) / 2.;
    auto g_jsd = (logit(p_mdl) - logit(p_avg)) / 2.;
    gi_data[i] = go_data[i] * g_jsd;
  }
  return true;
}

REGISTER_CPU_OPERATOR(BernoulliJSD, BernoulliJSDOp<float, CPUContext>);
REGISTER_CPU_OPERATOR(
    BernoulliJSDGradient,
    BernoulliJSDGradientOp<float, CPUContext>);
OPERATOR_SCHEMA(BernoulliJSD)
    .NumInputs(2)
    .NumOutputs(1)
    .SetDoc(R"DOC(
Computes the Jensen-Shannon divergence (JSD) between two Bernoulli
distributions, where each is parametrized by a single probability.
)DOC")
    .Input(0, "X", "array of probabilities for prediction")
    .Input(1, "T", "array of probabilities for target")
    .Output(0, "L", "array of JSD losses");
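
// A minimal sketch of how BernoulliJSD might appear in a NetDef, in protobuf
// text format (blob names X, T, and L are illustrative, not fixed by the op):
//
//   op {
//     input: "X"
//     input: "T"
//     output: "L"
//     type: "BernoulliJSD"
//   }
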
OPERATOR_SCHEMA(BernoulliJSDGradient).NumInputs(3).NumOutputs(1);
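
// The gradient maker feeds the output gradient and both forward inputs
// (GO(0), X, T) to BernoulliJSDGradient, matching the Input(0), Input(1),
// Input(2) reads in its RunOnDevice above.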
class GetBernoulliJSDGradient : public GradientMakerBase {
  using GradientMakerBase::GradientMakerBase;
  vector<OperatorDef> GetGradientDefs() override {
    return SingleGradientDef(
        "BernoulliJSDGradient",
        "",
        vector<string>{GO(0), I(0), I(1)},
        vector<string>{GI(0)});
  }
};
REGISTER_GRADIENT(BernoulliJSD, GetBernoulliJSDGradient);

} // namespace caffe2