Use TORCH_EXTENSION_NAME macro to avoid mismatched module/extension name (#5277)

* Warn users about mismatched module/extension name

* Define TORCH_EXTENSION_NAME macro
This commit is contained in:
Peter Goldsborough
2018-02-16 19:31:04 -08:00
committed by Soumith Chintala
parent 5c93ca258b
commit 22fe542b8e
8 changed files with 36 additions and 18 deletions

View File

@ -1,11 +1,20 @@
#include <torch/torch.h>
#include "doubler.h"
using namespace at;
Tensor exp_add(Tensor x, Tensor y);
Tensor tanh_add(Tensor x, Tensor y) {
return x.tanh() + y.tanh();
}
PYBIND11_MODULE(jit_extension, m) {
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("tanh_add", &tanh_add, "tanh(x) + tanh(y)");
m.def("exp_add", &exp_add, "exp(x) + exp(y)");
py::class_<Doubler>(m, "Doubler")
.def(py::init<int, int>())
.def("forward", &Doubler::forward)
.def("get", &Doubler::get);
}