mirror of https://github.com/pytorch/pytorch.git
Commit: tanh
@@ -26,6 +26,7 @@ cc_library(
        "relu_op.cc",
        "softmax_op.cc",
        "summarize_op.cc",
        "tanh_op.cc",
        "tensor_protos_db_input.cc",
        "utility_ops.cc",
    ],
@@ -56,6 +57,7 @@ cuda_library(
        "relu_op.cu",
        "softmax_op.cu",
        "summarize_op.cu",
        "tanh_op.cu",
    ],
    deps = [
        ":operators_headers",
caffe2/operators/tanh_op.cc (new file, 56 lines)
@@ -0,0 +1,56 @@
#include <cmath>

#include "caffe2/operators/elementwise_op.h"

namespace caffe2 {

template <typename T>
struct TanhCPUFunctor {
  inline void operator()(const int n, const T* x,
                         T* y, CPUContext* device_context) {
    for (int i = 0; i < n; ++i) {
      y[i] = tanh(x[i]);
    }
  }
  inline bool InplaceAllowed() {
    return true;
  }
};

template <typename T>
struct TanhGradientCPUFunctor {
  inline void operator()(const int n, const T* y, const T* dy,
                         T* dx, CPUContext* device_context) {
    for (int i = 0; i < n; ++i) {
      dx[i] = dy[i] * (1 - y[i] * y[i]);
    }
  }
  inline bool InplaceAllowed(const int input_id, const int output_id) {
    if (input_id == 1 && output_id == 0) {
      return true;
    } else {
      return false;
    }
  }
};

namespace {
REGISTER_CPU_OPERATOR(
    Tanh, UnaryElementwiseOp<float, CPUContext, TanhCPUFunctor<float> >);
REGISTER_CPU_OPERATOR(
    TanhGradient, BinaryElementwiseOp<float, CPUContext,
                                      TanhGradientCPUFunctor<float> >);

struct GetTanhGradient : public GetGradientDefBase {
  static vector<OperatorDef>* Create(const OperatorDef& def) {
    return new vector<OperatorDef>{
        CreateOperatorDef(
            "TanhGradient", "",
            std::vector<string>{def.output(0),
                                GradientName(def.output(0))},
            std::vector<string>{GradientName(def.input(0))})};
  }
};
REGISTER_GRADIENT(Tanh, GetTanhGradient);
}  // namespace
}  // namespace caffe2
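A note on the derivation (not part of the commit): because the forward pass produces y = tanh(x), the gradient can be written entirely in terms of the output, which is why GetTanhGradient above feeds the op's output Y and its gradient dY into TanhGradient instead of the original input X:

    \frac{d}{dx}\tanh(x) = 1 - \tanh^2(x) = 1 - y^2,
    \qquad\text{so}\qquad
    \frac{\partial L}{\partial x_i} = \frac{\partial L}{\partial y_i}\,\bigl(1 - y_i^2\bigr),

which is exactly the elementwise update dx[i] = dy[i] * (1 - y[i] * y[i]) in TanhGradientCPUFunctor.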
caffe2/operators/tanh_op.cu (new file, 60 lines)
@@ -0,0 +1,60 @@
#include <cmath>

#include "caffe2/core/context_gpu.h"
#include "caffe2/operators/elementwise_op.h"

namespace caffe2 {

template <typename T>
__global__ void TanhKernel(const int N, const T* X, T* Y) {
  CUDA_1D_KERNEL_LOOP(i, N) {
    Y[i] = tanh(X[i]);
  }
}

template <typename T>
__global__ void TanhGradientKernel(const int N, const T* Y, const T* dY,
                                   T* dX) {
  CUDA_1D_KERNEL_LOOP(i, N) {
    dX[i] = dY[i] * (1 - Y[i] * Y[i]);
  }
}

template <typename T>
struct TanhCUDAFunctor {
  inline void operator()(const int n, const float* x,
                         float* y, CUDAContext* device_context) {
    TanhKernel<T><<<CAFFE_GET_BLOCKS(n), CAFFE_CUDA_NUM_THREADS,
                    0, device_context->cuda_stream()>>>(n, x, y);
    return;
  }
  inline bool InplaceAllowed() {
    return true;
  }
};

template <typename T>
struct TanhGradientCUDAFunctor {
  inline void operator()(const int n, const T* y, const T* dy,
                         T* dx, CUDAContext* device_context) {
    TanhGradientKernel<T><<<CAFFE_GET_BLOCKS(n), CAFFE_CUDA_NUM_THREADS,
                            0, device_context->cuda_stream()>>>(n, y, dy, dx);
    return;
  }
  inline bool InplaceAllowed(const int input_id, const int output_id) {
    if (input_id == 1 && output_id == 0) {
      return true;
    } else {
      return false;
    }
  }
};

namespace {
REGISTER_CUDA_OPERATOR(
    Tanh, UnaryElementwiseOp<float, CUDAContext, TanhCUDAFunctor<float> >);
REGISTER_CUDA_OPERATOR(
    TanhGradient, BinaryElementwiseOp<float, CUDAContext,
                                      TanhGradientCUDAFunctor<float> >);
}  // namespace
}  // namespace caffe2
@@ -27,7 +27,6 @@ else:
    ]


class TestConvLegacyPooling(unittest.TestCase):
    def setUp(self):
        self.test_configs = [
@@ -213,5 +212,24 @@ class TestRelu(unittest.TestCase):
            self.assertTrue(res)


class TestTanh(unittest.TestCase):
    def setUp(self):
        self.test_configs = [
            (1, 1),
            (2, 1),
            (1, 2, 3, 4),
        ]

    def testTanh(self):
        for input_size in self.test_configs:
            op = core.CreateOperator("Tanh")(["X"], ["Y"])
            X = np.random.rand(*input_size).astype(np.float32) - 0.5
            res = device_checker.CheckSimple(op, [X], [0])
            self.assertTrue(res)
            for checker in gradient_checkers:
                res, grad, grad_estimated = checker.CheckSimple(op, [X], 0, [0])
                self.assertTrue(res)


if __name__ == '__main__':
    unittest.main()
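For context, the gradient_checkers used in testTanh compare the operator's analytic gradient against a numeric estimate. The snippet below is a minimal, self-contained NumPy sketch of that idea applied to the tanh identity above; it is an illustration only, not the project's checker, and the helper name is hypothetical:

import numpy as np

# Hypothetical helper, illustration only: compare the analytic tanh gradient,
# dX = dY * (1 - Y^2), against a central finite-difference estimate of
# d(sum(dY * tanh(X))) / dX.
def numeric_vs_analytic_tanh_grad(shape=(2, 3), eps=1e-3, seed=0):
    rng = np.random.RandomState(seed)
    X = rng.rand(*shape) - 0.5
    dY = rng.rand(*shape)

    Y = np.tanh(X)
    analytic = dY * (1.0 - Y * Y)  # what the TanhGradient op computes elementwise

    numeric = np.empty_like(X)
    for idx in np.ndindex(*shape):
        Xp, Xm = X.copy(), X.copy()
        Xp[idx] += eps
        Xm[idx] -= eps
        numeric[idx] = np.sum(dY * (np.tanh(Xp) - np.tanh(Xm))) / (2 * eps)

    return np.max(np.abs(analytic - numeric))

print(numeric_vs_analytic_tanh_grad())  # expect roughly 1e-6 or smaller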