mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 13:44:15 +08:00
12 lines
209 B
C++
12 lines
209 B
C++
#include <torch/torch.h>
|
|
|
|
using namespace at;
|
|
|
|
/// Computes the element-wise expression tanh(x) + tanh(y).
///
/// The inputs are taken by const reference, matching ATen's own operator
/// signatures. Passing a Tensor by value copies its reference-counted handle
/// on every call, which is wasted work here because neither input is stored
/// or modified.
///
/// @param x first input tensor; read only.
/// @param y second input tensor; read only.
/// @return a new tensor equal to x.tanh() + y.tanh(). No shape or dtype checks
///         are done here; combining the two results is left to Tensor's
///         operator+.
Tensor tanh_add(const Tensor& x, const Tensor& y) {
  return x.tanh() + y.tanh();
}
|
|
|
|
// Python entry point for the `jit_extension` extension module.
// It exposes the C++ tanh_add function to Python under the same name.
PYBIND11_MODULE(jit_extension, mod) {
  // Arguments: Python-visible name, the C++ function, and the docstring
  // attached to the Python callable.
  mod.def(
      "tanh_add",
      &tanh_add,
      "tanh(x) + tanh(y)");
}
|