[caffe2] Reintroduce Log1p operator (#55073)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/55073

The original diff, D27422219 (d92e2520de), was reverted; this change reintroduces the op.

Reviewed By: ChunliF

Differential Revision: D27473735

fbshipit-source-id: 1af0281724e9ada699ebf2045d51f65083daf5b4
Author: Oleg Khabinov
Date: 2021-03-31 22:27:59 -07:00
Committed by: Facebook GitHub Bot
Parent: 547346d663
Commit: 6145ac07b5
4 changed files with 196 additions and 0 deletions

caffe2/operators/log1p_op.cc

@@ -0,0 +1,74 @@
#include "caffe2/operators/log1p_op.h"
#include "caffe2/utils/eigen_utils.h"

#include <algorithm>
#include <functional>
#include <numeric>

namespace caffe2 {

template <>
template <typename T>
bool Log1pGradientFunctor<CPUContext>::Forward(
    const std::vector<int>& X_dims,
    const std::vector<int>& /* dY_dims */,
    const T* X,
    const T* dY,
    T* dX,
    CPUContext* /* context */) const {
  const int size = std::accumulate(
      X_dims.cbegin(), X_dims.cend(), 1, std::multiplies<int>());
  ConstEigenVectorArrayMap<T> dY_arr(dY, size);
  ConstEigenVectorArrayMap<T> X_arr(X, size);
  // d/dx log(1 + x) = 1 / (1 + x), applied element-wise.
  EigenVectorMap<T>(dX, size) = dY_arr / (T(1) + X_arr);
  return true;
}

REGISTER_CPU_OPERATOR(
    Log1p,
    UnaryElementwiseOp<TensorTypes<float>, CPUContext, Log1pFunctor<CPUContext>>);
REGISTER_CPU_OPERATOR(
    Log1pGradient,
    BinaryElementwiseOp<
        TensorTypes<float>,
        CPUContext,
        Log1pGradientFunctor<CPUContext>>);

OPERATOR_SCHEMA(Log1p)
    .NumInputs(1)
    .NumOutputs(1)
    .IdenticalTypeAndShape()
    .SetDoc(R"DOC(
Calculates Log1p, i.e. log(1 + x), of the given input tensor element-wise. This
operation can be done in an in-place fashion too, by providing the same input
and output blobs.

GitHub Link:
- https://github.com/pytorch/pytorch/blob/master/caffe2/operators/log1p_op.cc
)DOC")
    .Input(0, "input", "Input data blob to be operated on.")
    .Output(0, "output", "Output data blob with the same shape as the input.")
    .InheritOnnxSchema();

OPERATOR_SCHEMA(Log1pGradient)
    .NumInputs(2)
    .NumOutputs(1)
    .IdenticalTypeAndShapeOfInput(0);

namespace {

class GetLog1pGradient : public GradientMakerBase {
  using GradientMakerBase::GradientMakerBase;
  std::vector<OperatorDef> GetGradientDefs() override {
    // Feed the forward input X and the output gradient dY into
    // Log1pGradient; bind its output to the input gradient dX.
    return SingleGradientDef(
        "Log1pGradient",
        "",
        std::vector<std::string>{I(0), GO(0)},
        std::vector<std::string>{GI(0)});
  }
};

} // namespace

REGISTER_GRADIENT(Log1p, GetLog1pGradient);

} // namespace caffe2
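
The schema above notes that Log1p can run in place. As a minimal sketch of that usage through caffe2's standard Python frontend (assuming the Python bindings are built; blob names and values here are illustrative):

import numpy as np
from caffe2.python import core, workspace

# Feed an input blob, then run Log1p with the same blob as input and output.
workspace.FeedBlob("X", np.array([0.0, 1.0, 3.0], dtype=np.float32))
op = core.CreateOperator("Log1p", ["X"], ["X"])  # in-place: output == input
workspace.RunOperatorOnce(op)
print(workspace.FetchBlob("X"))  # log1p([0, 1, 3]) ~= [0., 0.6931, 1.3863]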

caffe2/operators/log1p_op.cu

@@ -0,0 +1,60 @@
#include "caffe2/operators/log1p_op.h"

#include <algorithm>
#include <functional>
#include <numeric>

#include "caffe2/core/context_gpu.h"

namespace caffe2 {

namespace {

template <typename T>
__global__ void
Log1pGradientCUDAKernel(const int N, const T* dY, const T* X, T* dX) {
  CUDA_1D_KERNEL_LOOP(i, N) {
#if __CUDA_ARCH__ >= 350
    // __ldg reads through the read-only data cache on sm_35 and newer.
    dX[i] = __ldg(dY + i) / (__ldg(X + i) + T(1));
#else
    dX[i] = dY[i] / (X[i] + T(1));
#endif
  }
}

} // namespace

template <>
template <typename T>
bool Log1pGradientFunctor<CUDAContext>::Forward(
    const std::vector<int>& X_dims,
    const std::vector<int>& /* dY_dims */,
    const T* X,
    const T* dY,
    T* dX,
    CUDAContext* context) const {
  const int size = std::accumulate(
      X_dims.cbegin(), X_dims.cend(), 1, std::multiplies<int>());
  Log1pGradientCUDAKernel<T>
      <<<CAFFE_GET_BLOCKS(size),
         CAFFE_CUDA_NUM_THREADS,
         0,
         context->cuda_stream()>>>(size, dY, X, dX);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
  return true;
}

REGISTER_CUDA_OPERATOR(
    Log1p,
    UnaryElementwiseOp<
        TensorTypes<float>,
        CUDAContext,
        Log1pFunctor<CUDAContext>>);
REGISTER_CUDA_OPERATOR(
    Log1pGradient,
    BinaryElementwiseOp<
        TensorTypes<float>,
        CUDAContext,
        Log1pGradientFunctor<CUDAContext>>);

} // namespace caffe2
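
For reference, the gradient that both the CPU functor and this CUDA kernel implement follows directly from the derivative of log1p: with Y = log(1 + X), dY/dX = 1 / (1 + X), so dX = dY / (1 + X) element-wise. That is exactly dY_arr / (T(1) + X_arr) on the CPU path and __ldg(dY + i) / (__ldg(X + i) + T(1)) on the GPU path.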

caffe2/operators/log1p_op.h

@@ -0,0 +1,34 @@
#ifndef CAFFE2_OPERATORS_LOG1P_OP_H_
#define CAFFE2_OPERATORS_LOG1P_OP_H_

#include <vector>

#include "caffe2/operators/elementwise_ops.h"
#include "caffe2/utils/math.h"

namespace caffe2 {

template <class Context>
struct Log1pFunctor {
  template <typename T>
  bool operator()(const int N, const T* X, T* Y, Context* context) const {
    // Y = log(1 + X), computed by the shared math library.
    math::Log1p(N, X, Y, context);
    return true;
  }
};

template <class Context>
struct Log1pGradientFunctor {
  template <typename T>
  bool Forward(
      const std::vector<int>& X_dims,
      const std::vector<int>& dY_dims,
      const T* X,
      const T* dY,
      T* dX,
      Context* context) const;
};

} // namespace caffe2

#endif // CAFFE2_OPERATORS_LOG1P_OP_H_
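
GetLog1pGradient in log1p_op.cc wires the forward input (I(0)) and the output gradient (GO(0)) into Log1pGradient, binding its output to the input gradient (GI(0)). A sketch of how that registration surfaces through caffe2's Python net API (AddGradientOperators auto-fills the gradient of the listed outputs with ones; the concrete values are illustrative):

import numpy as np
from caffe2.python import core, workspace

net = core.Net("log1p_demo")
net.Log1p(["X"], ["Y"])
# Inserts Log1pGradient via the registered gradient maker and
# seeds the gradient of "Y" with ones.
grad_map = net.AddGradientOperators(["Y"])

workspace.FeedBlob("X", np.array([0.0, 1.0, 3.0], dtype=np.float32))
workspace.RunNetOnce(net)
print(workspace.FetchBlob(str(grad_map["X"])))  # 1 / (1 + X) = [1., 0.5, 0.25]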

caffe2/python/operator_test/elementwise_ops_test.py

@@ -760,6 +760,34 @@ class TestElementwiseOps(hu.HypothesisTestCase):
        )
        self.assertDeviceChecks(dc, op, [X], [0])

    @given(X=hu.tensor(dtype=np.float32), **hu.gcs)
    @settings(deadline=10000)
    def test_log1p(self, X, gc, dc):
        op = core.CreateOperator(
            "Log1p",
            ["X"],
            ["Y"]
        )

        def ref_log1p(input):
            result = np.log1p(input)
            return (result,)

        def ref_log1p_grad(g_out, outputs, fwd_inputs):
            # dL/dX = dL/dY * 1 / (1 + X)
            result = g_out / (fwd_inputs[0] + 1)
            return (result,)

        self.assertReferenceChecks(
            device_option=gc,
            op=op,
            inputs=[X],
            reference=ref_log1p,
            output_to_grad="Y",
            grad_reference=ref_log1p_grad,
            ensure_outputs_are_inferred=True,
        )
        self.assertDeviceChecks(dc, op, [X], [0])

if __name__ == "__main__":
    unittest.main()