Files
pytorch/test/cpp_extensions/jit_extension.cpp
2018-02-01 22:42:07 -08:00

12 lines
209 B
C++

#include <torch/torch.h>
using namespace at;
/// @brief Computes tanh(x) + tanh(y).
///
/// Test helper exported to Python by the module binding in this file.
///
/// Arguments are taken by const reference. A Tensor is a reference-counted
/// handle, so passing it by value costs an extra (atomic) refcount
/// increment/decrement per argument for no benefit here.
///
/// @param x first input tensor.
/// @param y second input tensor; presumably shape-compatible with x for
///          `+` (NOTE(review): no check is done here, confirm with callers).
/// @return the tensor produced by `x.tanh() + y.tanh()`.
Tensor tanh_add(const Tensor& x, const Tensor& y) {
  return x.tanh() + y.tanh();
}
// Python module "jit_extension": registers tanh_add under the Python name
// "tanh_add", with "tanh(x) + tanh(y)" as its docstring.
PYBIND11_MODULE(jit_extension, mod) {
  mod.def("tanh_add", &tanh_add, "tanh(x) + tanh(y)");
}