Files
pytorch/caffe2/operators/minmax_gradient_ops.cc
Nikita Shulga a9b0a921d5 Disable avoid-non-const-global-variables lint check (#62008)
Summary:
The GoogleTest `TEST` macro and `DEFINE_DISPATCH` are both non-compliant with this check, so it is disabled.

All changes except those to `.clang-tidy` were generated using the following script:
```
for i in `find . -type f -iname "*.c*" -or -iname "*.h"|xargs grep cppcoreguidelines-avoid-non-const-global-variables|cut -f1 -d:|sort|uniq`;  do sed -i "/\/\/ NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)/d" $i; done
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/62008

Reviewed By: driazati, r-barnes

Differential Revision: D29838584

Pulled By: malfet

fbshipit-source-id: 1b2f8602c945bd4ce50a9bfdd204755556e31d13
2021-07-22 18:04:40 -07:00

67 lines
2.1 KiB
C++

#include "caffe2/operators/minmax_ops.h"

#include <cstdint>
#include <string>
#include <vector>

#include "caffe2/utils/eigen_utils.h"
namespace caffe2 {
// Shared backward kernel for MaxGradient / MinGradient (registered below via
// MaxGradientOp / MinGradientOp, which presumably derive from this base).
//
// Inputs:
//   0      Y   - forward output
//   1      dY  - gradient w.r.t. Y
//   2 ...  X_i - forward inputs, one per gradient output
// Outputs:
//   i      dX_i - gradient w.r.t. X_i, shaped like X_i
//
// dX_i equals dY wherever X_i == Y and zero elsewhere. If several inputs tie
// at a position, each of them receives the full dY there (the gradient is not
// split). A NaN in Y never compares equal, so it yields zero gradient.
template <typename T, class Context>
bool SelectGradientOpBase<T, Context>::RunOnDevice() {
  const auto& Y = Input(0);
  const auto& dY = Input(1);
  // Tensor::numel() is 64-bit; keep the full width so very large tensors are
  // not truncated when building the Eigen maps below.
  const int64_t N = Y.numel();
  // Every buffer below is mapped with N elements, so validate sizes up front
  // instead of reading past the end of a smaller tensor.
  CAFFE_ENFORCE_EQ(
      dY.numel(), N, "dY must have the same number of elements as Y");
  CAFFE_ENFORCE_GE(
      InputSize(),
      OutputSize() + 2,
      "Expected Y, dY and at least one forward input per gradient output");
  ConstEigenVectorArrayMap<T> Y_arr(Y.template data<T>(), N);
  ConstEigenVectorArrayMap<T> dY_arr(dY.template data<T>(), N);
  for (int i = 0; i < OutputSize(); ++i) {
    const auto& Xi = Input(i + 2);
    CAFFE_ENFORCE_EQ(
        Xi.numel(),
        N,
        "Forward input ",
        i,
        " must have the same number of elements as Y");
    auto* dXi = Output(i, Xi.sizes(), at::dtype<T>());
    ConstEigenVectorArrayMap<T> Xi_arr(Xi.template data<T>(), N);
    EigenVectorArrayMap<T> dXi_arr(dXi->template mutable_data<T>(), N);
    // Mask of positions where this input produced the extremum, times dY.
    dXi_arr = (Xi_arr == Y_arr).template cast<T>() * dY_arr;
  }
  return true;
}
// Schemas and CPU (float) kernels for the two gradient ops. Each takes the
// forward output, its gradient and one or more forward inputs (so >= 3
// inputs), and produces one or more input gradients.
OPERATOR_SCHEMA(MaxGradient)
    .NumInputs(3, INT_MAX)
    .NumOutputs(1, INT_MAX);
REGISTER_CPU_OPERATOR(MaxGradient, MaxGradientOp<float, CPUContext>);

OPERATOR_SCHEMA(MinGradient)
    .NumInputs(3, INT_MAX)
    .NumOutputs(1, INT_MAX);
REGISTER_CPU_OPERATOR(MinGradient, MinGradientOp<float, CPUContext>);
namespace {
class GetMaxGradient : public GradientMakerBase {
using GradientMakerBase::GradientMakerBase;
std::vector<OperatorDef> GetGradientDefs() override {
std::vector<std::string> inputs = {O(0), GO(0)};
std::vector<std::string> grad_inputs;
for (int i = 0; i < def_.input_size(); ++i) {
inputs.push_back(I(i));
grad_inputs.push_back(GI(i));
}
return SingleGradientDef("MaxGradient", "", inputs, grad_inputs);
}
};
class GetMinGradient : public GradientMakerBase {
using GradientMakerBase::GradientMakerBase;
vector<OperatorDef> GetGradientDefs() override {
std::vector<std::string> inputs = {O(0), GO(0)};
std::vector<std::string> grad_inputs;
for (int i = 0; i < def_.input_size(); ++i) {
inputs.push_back(I(i));
grad_inputs.push_back(GI(i));
}
return SingleGradientDef("MinGradient", "", inputs, grad_inputs);
}
};
} // namespace
// Attach the gradient makers defined above to the forward Min / Max ops.
REGISTER_GRADIENT(Min, GetMinGradient);
REGISTER_GRADIENT(Max, GetMaxGradient);
} // namespace caffe2